CA2 Part A GAN Analysis

Jayden Yap Jean Hng p2112790

Introduction

Generative Adversarial Networks (GANs) were first introduced in 2014 by Ian Goodfellow. It consists of 2 neural networks: a generator network to create new data samples, and a discriminator network that tries to distinguish the generated samples from real ones. The two networks are trained together in an adversarial way, where the generator tries to create samples that fool the discriminator to think it is real, and the discriminator tries to correctly identify the generated samples.

GANs have been applied to a wide range of tasks. One popular dataset for training GANs is the CIFAR-10 dataset, which consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images.

However, training GANs is a challenging task, as the generator and discriminator can easily get stuck, such as mode collapse (where the generator produces limited variations of the same sample) or overfitting (where the generator produces samples that are too similar to the training data). Researchers have proposed various techniques to stabilize GAN training, such as using different loss functions or architectures, but the problem is still an active area of research.

For my project, i want to experiment with different architectures, loss functions, and hyperparameters to train a GAN on the CIFAR-10 dataset, and compare its performance to the state-of-the-art models via FID metrics.

In [1]:
!nvidia-smi
Sun Feb  5 06:57:36 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.60.13    Driver Version: 525.60.13    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:0A:00.0 Off |                  N/A |
|  0%   53C    P8    16W / 240W |      1MiB /  8192MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
In [2]:
import tensorflow as tf
from tensorflow.keras import backend as K
import numpy as np
from matplotlib import pyplot as plt
import math
In [3]:
#using school server 8 GPUS, select one gpu
import os
os.environ["CUDA_VISIBLE_DEVICES"]="7"
tf.config.list_physical_devices(
    device_type=None
)
Out[3]:
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
In [4]:
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for i in physical_devices:
    tf.config.experimental.set_memory_growth(i, True)
In [5]:
#hide warning for better outputs
import warnings
warnings.filterwarnings("ignore")
In [6]:
%%javascript
(function(on) {
    const e = $("<a>Setup failed</a>");
    const ns = "js_jupyter_suppress_warnings";
    var cssrules = $("#" + ns);
    if(!cssrules.length)
        cssrules = $("<style id='" + ns + "' type='text/css'>div.output_stderr { } </style>").appendTo("head");
    e.click(function() {
        var s = 'Showing';
        cssrules.empty()
        if(on) {
            s = 'Hiding';
            cssrules.append("div.output_stderr, div[data-mime-type*='.stderr'] { display:none; }");
        }
        e.text(s + ' warnings (click to toggle)');
        on = !on;
    }).click();
    $(element).append(e);
})(true);

Common functions used for data loading/model training

In [7]:
# load and prepare cifar10 training images
def loadRealSamples():
    # load cifar10 dataset
    (trainX, _), (testX, _) = tf.keras.datasets.cifar10.load_data()
    # convert from unsigned ints to floats
    X = trainX.astype('float32')
    X2=testX.astype('float32')
    X=np.concatenate((X,X2), axis=0)
    # scale from [0,255] to [
    # -1,1]
    X = (X - 127.5) / 127.5
    return X

def loadRealLabels():
    # load cifar10 dataset
    ( _,trainY), (_,testY) = tf.keras.datasets.cifar10.load_data()
    Y=np.concatenate((trainY,testY), axis=0)
    Y=tf.keras.utils.to_categorical(Y, 10)
    return Y

#load and prepare only test images
def loadTestSamples():
    # load cifar10 dataset
    (_, _), (testX, _) = tf.keras.datasets.cifar10.load_data()
    # convert from unsigned ints to floats
    X=testX.astype('float32')
    # scale from [0,255] to [
    # -1,1]
    X = (X - 127.5) / 127.5
    return X

# select some real image samples
def selectRealSamples(dataset, n_samples):
    # choose random instances
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    # generate 'real' class labels (1)
    y = np.ones((n_samples, 1))
    return X, y

# generate points in latent space as input for the generator
def generateLatentPoints(latentDim, n_samples):
    # generate points in the latent space
    x_input = np.random.randn(latentDim * n_samples)
    # reshape into a batch of inputs for the network
    x_input = x_input.reshape(n_samples, latentDim)
    return x_input


# use the generator to generate n fake examples, with class labels
def generateFakeSamples(genModel, latentDim, n_samples):
    # generate points in latent space
    x_input = generateLatentPoints(latentDim, n_samples)
    # predict outputs
    X = genModel.predict(x_input)
    # create 'fake' class labels (0)
    y = np.zeros((n_samples, 1))
    return X, y

# create and save a plot of generated images
def savePlot(examples, epoch, name, n=5,save_frequency=10):
    # scale from [-1,1] to [0,1]
    examples = (examples + 1) / 2.0
    plt.figure(figsize=(5,5))
    plt.suptitle(f'{name} epoch {epoch+1}')
    # plot images
    for i in range(n * n):
        # define subplot
        plt.subplot(n, n, 1 + i)
        # turn off axis
        plt.axis('off')
        # plot raw pixel data
        
        plt.imshow(examples[i])
    # save plot to file every 10 epochs
    if (epoch+1)%save_frequency==0:
        subfolder_path = os.path.join(os.getcwd(), 'imgs')
        subfolder_path = os.path.join(subfolder_path, f'{name}')
        if not os.path.exists(subfolder_path):
            os.makedirs(subfolder_path)
        filename = os.path.join(subfolder_path, f'plot{epoch+1}.png')
        plt.savefig(filename[:])
    plt.show()

    
# evaluate the discriminator, plot generated images, save generator model
def summarisePerf(epoch, genModel, discModel, dataset, latentDim, n_samples=150, name='model'):
    # prepare real samples
    X_real, y_real = selectRealSamples(dataset, n_samples)
    # evaluate discriminator on real examples
    _, acc_real = discModel.evaluate(X_real, y_real, verbose=0)
    # prepare fake examples
    x_fake, y_fake = generateFakeSamples(genModel, latentDim, n_samples)
    # evaluate discriminator on fake examples
    _, acc_fake = discModel.evaluate(x_fake, y_fake, verbose=0)
    # summarize discriminator performance
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real * 100, acc_fake * 100))
    # save plot
    savePlot(x_fake, epoch,name)
        

    

FID Class/Functions

Frechet Inception Distance: Evaluates distribution of generated images by comparing with real ground truth images.

required libraries in addition to torch, restart kernel after running

In [8]:
# !pip install torchmetrics
# !pip install torch-fidelity
# !pip install --upgrade typing-extensions
# !pip install torchvision
Collecting torchmetrics
  Using cached torchmetrics-0.11.1-py3-none-any.whl (517 kB)
Requirement already satisfied: packaging in /opt/conda/lib/python3.8/site-packages (from torchmetrics) (20.4)
Requirement already satisfied: numpy>=1.17.2 in /opt/conda/lib/python3.8/site-packages (from torchmetrics) (1.19.1)
Collecting torch>=1.8.1
  Using cached torch-1.13.1-cp38-cp38-manylinux1_x86_64.whl (887.4 MB)
Requirement already satisfied: typing-extensions; python_version < "3.9" in /opt/conda/lib/python3.8/site-packages (from torchmetrics) (3.7.4.2)
Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.8/site-packages (from packaging->torchmetrics) (2.4.7)
Requirement already satisfied: six in /opt/conda/lib/python3.8/site-packages (from packaging->torchmetrics) (1.15.0)
Collecting nvidia-cublas-cu11==11.10.3.66; platform_system == "Linux"
  Using cached nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl (317.1 MB)
Collecting nvidia-cuda-runtime-cu11==11.7.99; platform_system == "Linux"
  Using cached nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl (849 kB)
Collecting nvidia-cudnn-cu11==8.5.0.96; platform_system == "Linux"
  Using cached nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl (557.1 MB)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == "Linux"
  Using cached nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl (21.0 MB)
Requirement already satisfied: wheel in /opt/conda/lib/python3.8/site-packages (from nvidia-cublas-cu11==11.10.3.66; platform_system == "Linux"->torch>=1.8.1->torchmetrics) (0.35.1)
Requirement already satisfied: setuptools in /opt/conda/lib/python3.8/site-packages (from nvidia-cublas-cu11==11.10.3.66; platform_system == "Linux"->torch>=1.8.1->torchmetrics) (49.6.0.post20200814)
Installing collected packages: nvidia-cublas-cu11, nvidia-cuda-runtime-cu11, nvidia-cudnn-cu11, nvidia-cuda-nvrtc-cu11, torch, torchmetrics
Successfully installed nvidia-cublas-cu11-11.10.3.66 nvidia-cuda-nvrtc-cu11-11.7.99 nvidia-cuda-runtime-cu11-11.7.99 nvidia-cudnn-cu11-8.5.0.96 torch-1.13.1 torchmetrics-0.11.1
Collecting torch-fidelity
  Using cached torch_fidelity-0.3.0-py3-none-any.whl (37 kB)
Requirement already satisfied: Pillow in /opt/conda/lib/python3.8/site-packages (from torch-fidelity) (7.2.0)
Requirement already satisfied: scipy in /opt/conda/lib/python3.8/site-packages (from torch-fidelity) (1.4.1)
Requirement already satisfied: torch in /opt/conda/lib/python3.8/site-packages (from torch-fidelity) (1.13.1)
Requirement already satisfied: tqdm in /opt/conda/lib/python3.8/site-packages (from torch-fidelity) (4.48.2)
Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (from torch-fidelity) (1.19.1)
Collecting torchvision
  Using cached torchvision-0.14.1-cp38-cp38-manylinux1_x86_64.whl (24.2 MB)
Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99; platform_system == "Linux" in /opt/conda/lib/python3.8/site-packages (from torch->torch-fidelity) (11.7.99)
Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == "Linux" in /opt/conda/lib/python3.8/site-packages (from torch->torch-fidelity) (11.7.99)
Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.8/site-packages (from torch->torch-fidelity) (3.7.4.2)
Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96; platform_system == "Linux" in /opt/conda/lib/python3.8/site-packages (from torch->torch-fidelity) (8.5.0.96)
Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66; platform_system == "Linux" in /opt/conda/lib/python3.8/site-packages (from torch->torch-fidelity) (11.10.3.66)
Requirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (from torchvision->torch-fidelity) (2.24.0)
Requirement already satisfied: wheel in /opt/conda/lib/python3.8/site-packages (from nvidia-cuda-runtime-cu11==11.7.99; platform_system == "Linux"->torch->torch-fidelity) (0.35.1)
Requirement already satisfied: setuptools in /opt/conda/lib/python3.8/site-packages (from nvidia-cuda-runtime-cu11==11.7.99; platform_system == "Linux"->torch->torch-fidelity) (49.6.0.post20200814)
Requirement already satisfied: chardet<4,>=3.0.2 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->torch-fidelity) (3.0.4)
Requirement already satisfied: idna<3,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->torch-fidelity) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->torch-fidelity) (1.25.10)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision->torch-fidelity) (2020.6.20)
Installing collected packages: torchvision, torch-fidelity
Successfully installed torch-fidelity-0.3.0 torchvision-0.14.1
Collecting typing-extensions
  Using cached typing_extensions-4.4.0-py3-none-any.whl (26 kB)
Installing collected packages: typing-extensions
  Attempting uninstall: typing-extensions
    Found existing installation: typing-extensions 3.7.4.2
    Uninstalling typing-extensions-3.7.4.2:
      Successfully uninstalled typing-extensions-3.7.4.2
Successfully installed typing-extensions-4.4.0
Requirement already satisfied: torchvision in /opt/conda/lib/python3.8/site-packages (0.14.1)
Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (from torchvision) (1.19.1)
Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.8/site-packages (from torchvision) (4.4.0)
Requirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (from torchvision) (2.24.0)
Requirement already satisfied: torch==1.13.1 in /opt/conda/lib/python3.8/site-packages (from torchvision) (1.13.1)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.8/site-packages (from torchvision) (7.2.0)
Requirement already satisfied: chardet<4,>=3.0.2 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision) (3.0.4)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision) (1.25.10)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision) (2020.6.20)
Requirement already satisfied: idna<3,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests->torchvision) (2.10)
Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99; platform_system == "Linux" in /opt/conda/lib/python3.8/site-packages (from torch==1.13.1->torchvision) (11.7.99)
Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66; platform_system == "Linux" in /opt/conda/lib/python3.8/site-packages (from torch==1.13.1->torchvision) (11.10.3.66)
Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96; platform_system == "Linux" in /opt/conda/lib/python3.8/site-packages (from torch==1.13.1->torchvision) (8.5.0.96)
Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == "Linux" in /opt/conda/lib/python3.8/site-packages (from torch==1.13.1->torchvision) (11.7.99)
Requirement already satisfied: wheel in /opt/conda/lib/python3.8/site-packages (from nvidia-cuda-runtime-cu11==11.7.99; platform_system == "Linux"->torch==1.13.1->torchvision) (0.35.1)
Requirement already satisfied: setuptools in /opt/conda/lib/python3.8/site-packages (from nvidia-cuda-runtime-cu11==11.7.99; platform_system == "Linux"->torch==1.13.1->torchvision) (49.6.0.post20200814)
In [8]:
# In case of errors from torch
# pip uninstall torch
# pip cache purge
# pip install torch -f https://download.pytorch.org/whl/torch_stable.html

Test on 10K Test Images

In [9]:
from torchmetrics.image.fid import FrechetInceptionDistance
import torch
fid=FrechetInceptionDistance(reset_real_features=False).to('cuda')
fid.reset()
real_images = loadTestSamples()
real_images = real_images.astype(np.float32)
real_images=torch.from_numpy(np.array(real_images))

#change dtypes, scale back first
real_images=(real_images+1)*(255/2)
real_images=real_images.type(torch.uint8)
#from 50,000x32x32x3 to 50,000x3x32x32
real_images = real_images.permute(0, 3, 1, 2)
real_images = real_images.to('cuda')
#COMPUTE FID

# create a dataloader with batch size 128
dataloader = torch.utils.data.DataLoader(real_images, batch_size=128, shuffle=True,
                                          drop_last=True
                                          )

# update the FID with real images
for batch,images in enumerate(dataloader):
    print(f'Updating with real images now...Batch{batch}',end='\r')
    fid.update(images, real=True)

print(f'\nFinished! Ready to take in generated image features')

def computeFID(generator,latentDim):
    fid.reset()
    #PREPARE DATA
    gen_images = generateFakeSamples(generator,latentDim,10000)[0]
#     gen_images=np.reshape(gen_images,(50000,3,32,32))
    gen_images = gen_images.astype(np.float32)
    gen_images = torch.from_numpy(np.array(gen_images))
    gen_images=(gen_images+1)*(255/2)
    gen_images=gen_images.type(torch.uint8)
    gen_images = gen_images.permute(0, 3, 1, 2)
    gen_images = gen_images.to('cuda')

    # create a dataloader with batch size 128
    dataloader_gen = torch.utils.data.DataLoader(gen_images, batch_size=128, shuffle=True,
                                            # pin_memory=True,
                                             drop_last=True)
    
    print('')
    # update the FID with generated images
    for batch,images in enumerate(dataloader_gen):
        print(f'Updating FID with generated images: Batch{batch}',end='\r')
        fid.update(images, real=False)
    
    print('\n')
    score = fid.compute()
    print(f'FID: {score.item()}')
    return score
Updating with real images now...Batch77
Finished! Ready to take in generated image features

Model 1 Baseline DCGAN

We will try to create a DCGAN now. Deep Convolutional Generative Adversarial Networks (DCGAN) is a variant of GANs where the generator and discriminator networks are composed of convolutional layers. DCGANs were proposed in 2016, and they have been used to generate high-quality images, such as faces and realistic landscapes. They are quite popular for CIFAR10 and more simple to implement than a complex architecture

In [10]:
# define the standalone discriminator model
def defineDisc(in_shape=(32, 32, 3)):
    model = tf.keras.models.Sequential()
    # normal
    model.add(tf.keras.layers.Conv2D(64, (3, 3), padding='same', input_shape=in_shape))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # classifier
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
    # compile model
    opt = tf.keras.optimizers.Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

# define the standalone generator model
def defineGen(latent_dim):
    model = tf.keras.models.Sequential()
    # foundation for 4x4 image
    n_nodes = 128 * 4 * 4
    model.add(tf.keras.layers.Dense(n_nodes, input_dim=latent_dim))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    model.add(tf.keras.layers.Reshape((4, 4, 128)))
    # upsample to 8x8
    model.add(tf.keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # upsample to 16x16
    model.add(tf.keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # upsample to 32x32
    model.add(tf.keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # output layer
    model.add(tf.keras.layers.Conv2D(3, (3, 3), activation='tanh', padding='same'))
    return model
In [11]:
# define the combined generator and discriminator model, for updating the generator
def defineGAN(genModel, discModel):
    # make weights in the discriminator not trainable
    discModel.trainable = False
    # connect them
    model = tf.keras.models.Sequential()
    # add generator
    model.add(genModel)
    # add the discriminator
    model.add(discModel)
    # compile model
    opt = tf.keras.optimizers.Adam(lr=0.0002,beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

# train the generator and discriminator
def trainGAN(genModel, discModel, gan_model, dataset, latentDim, n_epochs=200, n_batch=128,fid_frequency=50, name='model'):
    bat_per_epo = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for i in range(n_epochs):
        # enumerate batches over the training set
        for j in range(bat_per_epo):
            # get randomly selected 'real' samples
            X_real, y_real = selectRealSamples(dataset, half_batch)
            # update discriminator model weights
            d_loss1, _ = discModel.train_on_batch(X_real, y_real)
            # generate 'fake' examples
            X_fake, y_fake = generateFakeSamples(genModel, latentDim, half_batch)
            # update discriminator model weights
            d_loss2, _ = discModel.train_on_batch(X_fake, y_fake)
            # prepare points in latent space as input for the generator
            X_gan = generateLatentPoints(latentDim, n_batch)
            # create inverted labels for the fake samples
            y_gan = np.ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            #progress bar per epoch 
            
            if j%25==0 or j==bat_per_epo-1:
                if j==bat_per_epo-1: #if last step
                    print('>epoch%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % (i + 1, j + 1, bat_per_epo, d_loss1, d_loss2, g_loss))
                else:
                    print('>epoch%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % (i + 1, j + 1, bat_per_epo, d_loss1, d_loss2, g_loss) , end='\r')
        
        #compute FID for every interval and last epoch
        if (i>0 and (i+1)%fid_frequency==0) or (i+1)==n_epochs:
            # Create a new model with the same architecture as the generator
            updated_gen_model = tf.keras.models.Sequential()
            updated_gen_model.add(gan_model.layers[0])
            
            # Set the weights of the new model to be the same as the generator portion of the gan_model
            updated_gen_model.set_weights(gan_model.layers[0].get_weights())
            computeFID(updated_gen_model,latentDim)
            
        # evaluate the model performance every few epochs
        if (i+1) % 2 == 0:
          summarisePerf(i, genModel, discModel, dataset, latentDim, name=name)
    return
In [12]:
# size of the latent space
latent_dim = 100
# create the discriminator
discModel = defineDisc()
# create the generator
genModel = defineGen(latent_dim)
# create the gan
gan_model = defineGAN(genModel, discModel)

print('***************DISCRIMINATOR**************')
discModel.summary()
print('***************GENERATOR**************')
genModel.summary()

gan_model.summary()
***************DISCRIMINATOR**************
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 32, 32, 64)        1792      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 64)        36928     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 64)          36928     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 4, 4, 64)          36928     
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 4, 4, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1025      
=================================================================
Total params: 113,601
Trainable params: 0
Non-trainable params: 113,601
_________________________________________________________________
***************GENERATOR**************
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 2048)              206848    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 2048)              0         
_________________________________________________________________
reshape (Reshape)            (None, 4, 4, 128)         0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 8, 8, 64)          131136    
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 16, 16, 64)        65600     
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 32, 32, 64)        65600     
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 32, 32, 3)         1731      
=================================================================
Total params: 470,915
Trainable params: 470,915
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_1 (Sequential)    (None, 32, 32, 3)         470915    
_________________________________________________________________
sequential (Sequential)      (None, 1)                 113601    
=================================================================
Total params: 584,516
Trainable params: 470,915
Non-trainable params: 113,601
_________________________________________________________________
In [13]:
%%time
# load image data
x_train = loadRealSamples()
# train model
trainGAN(genModel, discModel, gan_model, x_train, latent_dim, n_epochs=50,n_batch=256,fid_frequency=10,name='model1')
>epoch1, 234/234, d1=0.373, d2=0.437 g=2.252
>epoch2, 234/234, d1=0.548, d2=0.364 g=1.728
>Accuracy real: 62%, fake: 89%
>epoch3, 234/234, d1=0.638, d2=0.452 g=1.324
>epoch4, 234/234, d1=0.873, d2=0.578 g=1.226
>Accuracy real: 32%, fake: 97%
>epoch5, 234/234, d1=0.570, d2=0.513 g=1.035
>epoch6, 234/234, d1=0.607, d2=0.777 g=0.742
>Accuracy real: 62%, fake: 54%
>epoch7, 234/234, d1=0.725, d2=0.690 g=0.788
>epoch8, 234/234, d1=0.533, d2=0.581 g=1.053
>Accuracy real: 73%, fake: 79%
>epoch9, 234/234, d1=0.703, d2=0.682 g=0.806
>epoch10, 234/234, d1=0.623, d2=0.693 g=0.808

Updating FID with generated images: Batch389

FID: 232.79071044921875
>Accuracy real: 65%, fake: 72%
>epoch11, 234/234, d1=0.601, d2=0.598 g=0.975
>epoch12, 234/234, d1=0.535, d2=0.637 g=0.992
>Accuracy real: 71%, fake: 76%
>epoch13, 234/234, d1=0.521, d2=0.728 g=0.880
>epoch14, 234/234, d1=0.629, d2=0.658 g=0.948
>Accuracy real: 57%, fake: 73%
>epoch15, 234/234, d1=0.573, d2=0.570 g=1.003
>epoch16, 234/234, d1=0.589, d2=0.614 g=1.035
>Accuracy real: 67%, fake: 80%
>epoch17, 234/234, d1=0.679, d2=0.603 g=0.941
>epoch18, 234/234, d1=0.616, d2=0.673 g=0.838
>Accuracy real: 60%, fake: 67%
>epoch19, 234/234, d1=0.667, d2=0.643 g=0.859
>epoch20, 234/234, d1=0.703, d2=0.675 g=0.869

Updating FID with generated images: Batch389

FID: 134.72511291503906
>Accuracy real: 49%, fake: 80%
>epoch21, 234/234, d1=0.687, d2=0.666 g=0.832
>epoch22, 234/234, d1=0.674, d2=0.657 g=0.902
>Accuracy real: 48%, fake: 69%
>epoch23, 234/234, d1=0.526, d2=0.637 g=0.873
>epoch24, 234/234, d1=0.673, d2=0.580 g=0.935
>Accuracy real: 45%, fake: 85%
>epoch25, 234/234, d1=0.680, d2=0.623 g=0.862
>epoch26, 234/234, d1=0.584, d2=0.701 g=0.847
>Accuracy real: 70%, fake: 84%
>epoch27, 234/234, d1=0.581, d2=0.613 g=0.933
>epoch28, 234/234, d1=0.617, d2=0.601 g=0.875
>Accuracy real: 53%, fake: 89%
>epoch29, 234/234, d1=0.651, d2=0.750 g=0.834
>epoch30, 234/234, d1=0.650, d2=0.633 g=0.866

Updating FID with generated images: Batch389

FID: 107.86116790771484
>Accuracy real: 45%, fake: 75%
>epoch31, 234/234, d1=0.638, d2=0.691 g=0.772
>epoch32, 234/234, d1=0.638, d2=0.698 g=0.740
>Accuracy real: 63%, fake: 71%
>epoch33, 234/234, d1=0.713, d2=0.631 g=0.815
>epoch34, 234/234, d1=0.722, d2=0.662 g=0.787
>Accuracy real: 36%, fake: 84%
>epoch35, 234/234, d1=0.657, d2=0.692 g=0.797
>epoch36, 234/234, d1=0.684, d2=0.676 g=0.769
>Accuracy real: 51%, fake: 73%
>epoch37, 234/234, d1=0.698, d2=0.637 g=0.803
>epoch38, 234/234, d1=0.664, d2=0.741 g=0.765
>Accuracy real: 45%, fake: 79%
>epoch39, 234/234, d1=0.640, d2=0.689 g=0.746
>epoch40, 234/234, d1=0.713, d2=0.657 g=0.777

Updating FID with generated images: Batch389

FID: 89.67537689208984
>Accuracy real: 38%, fake: 75%
>epoch41, 234/234, d1=0.650, d2=0.678 g=0.764
>epoch42, 234/234, d1=0.738, d2=0.631 g=0.801
>Accuracy real: 31%, fake: 88%
>epoch43, 234/234, d1=0.710, d2=0.657 g=0.795
>epoch44, 234/234, d1=0.683, d2=0.651 g=0.770
>Accuracy real: 51%, fake: 77%
>epoch45, 234/234, d1=0.721, d2=0.626 g=0.798
>epoch46, 234/234, d1=0.690, d2=0.663 g=0.774
>Accuracy real: 32%, fake: 83%
>epoch47, 234/234, d1=0.676, d2=0.648 g=0.792
>epoch48, 234/234, d1=0.618, d2=0.688 g=0.736
>Accuracy real: 63%, fake: 59%
>epoch49, 234/234, d1=0.671, d2=0.692 g=0.749
>epoch50, 234/234, d1=0.686, d2=0.637 g=0.803

Updating FID with generated images: Batch389

FID: 74.29251098632812
>Accuracy real: 37%, fake: 86%
CPU times: user 24min 38s, sys: 30min 46s, total: 55min 24s
Wall time: 19min 22s

Summary

Decent for first attempt but model quite long to get out of the 'noisy' phase where its just random patterns and not an image like CIFAR10. A lot of room to improve because the model was running into a lot of mode collapse (model creating variants of same image/pattern)

Model 2

  • More epochs ** will allow our model to train more, learn more and generate better images
  • Improved models ** More neurons and better layers will allow our model to learn more
  • Use lower batch size
    ** because GANs are highly sensitive to the balance between the generator and discriminator, and larger batch sizes can disrupt the balance
  • Testing batch normalisation: We will test the model with and without batch normalisation to see which is better ** Batch normalisation can improve training stability but might not always be the case for GANs

No Batch Normalisation

In [14]:
# define the standalone discriminator model
def defineDisc(in_shape=(32, 32, 3)):
    model = tf.keras.models.Sequential()
    # normal
    model.add(tf.keras.layers.Conv2D(128, (3, 3), padding='same', input_shape=in_shape))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # classifier
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
    # compile model
    opt = tf.keras.optimizers.Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

# define the standalone generator model
def defineGen(latent_dim):
    model = tf.keras.models.Sequential()
    # foundation for 4x4 image
    n_nodes = 256 * 4 * 4
    model.add(tf.keras.layers.Dense(n_nodes, input_dim=latent_dim))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    model.add(tf.keras.layers.Reshape((4, 4, 256)))
    # upsample to 8x8
    model.add(tf.keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # upsample to 16x16
    model.add(tf.keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # upsample to 32x32
    model.add(tf.keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # output layer
    model.add(tf.keras.layers.Conv2D(3, (3, 3), activation='tanh', padding='same'))
    return model
In [14]:
# define the combined generator and discriminator model, for updating the generator
def defineGAN(genModel, discModel):
    # make weights in the discriminator not trainable
    discModel.trainable = False
    # connect them
    model = tf.keras.models.Sequential()
    # add generator
    model.add(genModel)
    # add the discriminator
    model.add(discModel)
    # compile model
    opt = tf.keras.optimizers.Adam(lr=0.0002,beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

# train the generator and discriminator
def trainGAN(genModel, discModel, gan_model, dataset, latentDim, n_epochs=200, n_batch=128,fid_frequency=50,name='model'):
    bat_per_epo = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for i in range(n_epochs):
        
        for j in range(bat_per_epo):
            # get randomly selected 'real' samples
            X_real, y_real = selectRealSamples(dataset, half_batch)
            # update discriminator model weights
            d_loss1, _ = discModel.train_on_batch(X_real, y_real)
            # generate 'fake' examples
            X_fake, y_fake = generateFakeSamples(genModel, latentDim, half_batch)
            # update discriminator model weights
            d_loss2, _ = discModel.train_on_batch(X_fake, y_fake)
            # prepare points in latent space as input for the generator
            X_gan = generateLatentPoints(latentDim, n_batch)
            # create inverted labels for the fake samples
            y_gan = np.ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            #progress bar per epoch 
            
            if j%25==0 or j==bat_per_epo-1:
                if j==bat_per_epo-1: #if last step
                    print('>epoch%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % (i + 1, j + 1, bat_per_epo, d_loss1, d_loss2, g_loss))
                else:
                    print('>epoch%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % (i + 1, j + 1, bat_per_epo, d_loss1, d_loss2, g_loss) , end='\r')
        
        #compute FID for every interval and last epoch
        if (i>0 and (i+1)%fid_frequency==0) or (i+1)==n_epochs:
            # Create a new model with the same architecture as the generator
            updated_gen_model = tf.keras.models.Sequential()
            updated_gen_model.add(gan_model.layers[0])
            
            # Set the weights of the new model to be the same as the generator portion of the gan_model
            updated_gen_model.set_weights(gan_model.layers[0].get_weights())
            computeFID(updated_gen_model,latentDim)
            
        # evaluate the model performance every few epochs
        if (i+1) % 2 == 0:
          summarisePerf(i, genModel, discModel, dataset, latentDim,name=name)
    return
In [16]:
# size of the latent space
latent_dim = 100
# create the discriminator
discModel = defineDisc()
# create the generator
genModel = defineGen(latent_dim)
# create the gan
gan_model = defineGAN(genModel, discModel)

print('***************DISCRIMINATOR**************')
discModel.summary()
print('***************GENERATOR**************')
genModel.summary()

gan_model.summary()
***************DISCRIMINATOR**************
Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_5 (Conv2D)            (None, 32, 32, 128)       3584      
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 16, 16, 128)       147584    
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 8, 8, 64)          73792     
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU)   (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 4, 4, 64)          36928     
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU)   (None, 4, 4, 64)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 1024)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 1025      
=================================================================
Total params: 262,913
Trainable params: 0
Non-trainable params: 262,913
_________________________________________________________________
***************GENERATOR**************
Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 4096)              413696    
_________________________________________________________________
leaky_re_lu_12 (LeakyReLU)   (None, 4096)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 4, 4, 256)         0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 8, 8, 64)          262208    
_________________________________________________________________
leaky_re_lu_13 (LeakyReLU)   (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 16, 16, 128)       131200    
_________________________________________________________________
leaky_re_lu_14 (LeakyReLU)   (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 32, 32, 128)       262272    
_________________________________________________________________
leaky_re_lu_15 (LeakyReLU)   (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 32, 32, 3)         3459      
=================================================================
Total params: 1,072,835
Trainable params: 1,072,835
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_9 (Sequential)    (None, 32, 32, 3)         1072835   
_________________________________________________________________
sequential_8 (Sequential)    (None, 1)                 262913    
=================================================================
Total params: 1,335,748
Trainable params: 1,072,835
Non-trainable params: 262,913
_________________________________________________________________
In [17]:
%%time
# load image data
x_train = loadRealSamples()
# train model
trainGAN(genModel, discModel, gan_model, x_train, latent_dim, n_epochs=100,n_batch=64,fid_frequency=10,name='model2NoBN')
>epoch1, 937/937, d1=0.467, d2=0.563 g=1.013
>epoch2, 937/937, d1=0.620, d2=0.575 g=1.049
>Accuracy real: 68%, fake: 85%
>epoch3, 937/937, d1=0.648, d2=0.489 g=1.119
>epoch4, 937/937, d1=0.615, d2=0.611 g=0.966
>Accuracy real: 69%, fake: 85%
>epoch5, 937/937, d1=0.776, d2=0.758 g=1.131
>epoch6, 937/937, d1=0.569, d2=0.681 g=1.150
>Accuracy real: 56%, fake: 85%
>epoch7, 937/937, d1=0.636, d2=0.655 g=0.789
>epoch8, 937/937, d1=0.653, d2=0.658 g=0.841
>Accuracy real: 39%, fake: 69%
>epoch9, 937/937, d1=0.795, d2=0.556 g=0.936
>epoch10, 937/937, d1=0.667, d2=0.713 g=0.824

Updating FID with generated images: Batch389

FID: 74.96601867675781
>Accuracy real: 47%, fake: 67%
>epoch11, 937/937, d1=0.730, d2=0.668 g=0.761
>epoch12, 937/937, d1=0.710, d2=0.647 g=0.724
>Accuracy real: 55%, fake: 74%
>epoch13, 937/937, d1=0.580, d2=0.718 g=0.702
>epoch14, 937/937, d1=0.702, d2=0.641 g=0.793
>Accuracy real: 35%, fake: 85%
>epoch15, 937/937, d1=0.667, d2=0.712 g=0.751
>epoch16, 937/937, d1=0.689, d2=0.634 g=0.762
>Accuracy real: 35%, fake: 81%
>epoch17, 937/937, d1=0.705, d2=0.740 g=0.768
>epoch18, 937/937, d1=0.686, d2=0.631 g=0.752
>Accuracy real: 46%, fake: 78%
>epoch19, 937/937, d1=0.663, d2=0.623 g=0.809
>epoch20, 937/937, d1=0.710, d2=0.671 g=0.750

Updating FID with generated images: Batch389

FID: 61.5596923828125
>Accuracy real: 46%, fake: 69%
>epoch21, 937/937, d1=0.714, d2=0.643 g=0.763
>epoch22, 937/937, d1=0.707, d2=0.663 g=0.727
>Accuracy real: 51%, fake: 71%
>epoch23, 937/937, d1=0.661, d2=0.654 g=0.753
>epoch24, 937/937, d1=0.718, d2=0.662 g=0.738
>Accuracy real: 39%, fake: 85%
>epoch25, 937/937, d1=0.641, d2=0.660 g=0.752
>epoch26, 937/937, d1=0.645, d2=0.727 g=0.764
>Accuracy real: 39%, fake: 73%
>epoch27, 937/937, d1=0.710, d2=0.627 g=0.800
>epoch28, 937/937, d1=0.587, d2=0.699 g=0.807
>Accuracy real: 54%, fake: 79%
>epoch29, 937/937, d1=0.633, d2=0.700 g=0.718
>epoch30, 937/937, d1=0.690, d2=0.679 g=0.774

Updating FID with generated images: Batch389

FID: 73.81658935546875
>Accuracy real: 49%, fake: 75%
>epoch31, 937/937, d1=0.666, d2=0.650 g=0.739
>epoch32, 937/937, d1=0.678, d2=0.633 g=0.771
>Accuracy real: 45%, fake: 81%
>epoch33, 937/937, d1=0.617, d2=0.671 g=0.739
>epoch34, 937/937, d1=0.664, d2=0.709 g=0.762
>Accuracy real: 50%, fake: 77%
>epoch35, 937/937, d1=0.698, d2=0.653 g=0.806
>epoch36, 937/937, d1=0.622, d2=0.660 g=0.779
>Accuracy real: 50%, fake: 78%
>epoch37, 937/937, d1=0.649, d2=0.642 g=0.774
>epoch38, 937/937, d1=0.697, d2=0.616 g=0.801
>Accuracy real: 49%, fake: 76%
>epoch39, 937/937, d1=0.662, d2=0.654 g=0.790
>epoch40, 937/937, d1=0.631, d2=0.630 g=0.777

Updating FID with generated images: Batch389

FID: 48.53939437866211
>Accuracy real: 41%, fake: 78%
>epoch41, 937/937, d1=0.716, d2=0.626 g=0.781
>epoch42, 937/937, d1=0.747, d2=0.705 g=0.828
>Accuracy real: 39%, fake: 72%
>epoch43, 937/937, d1=0.661, d2=0.613 g=0.750
>epoch44, 937/937, d1=0.682, d2=0.725 g=0.802
>Accuracy real: 57%, fake: 69%
>epoch45, 937/937, d1=0.660, d2=0.654 g=0.813
>epoch46, 937/937, d1=0.673, d2=0.605 g=0.836
>Accuracy real: 43%, fake: 85%
>epoch47, 937/937, d1=0.635, d2=0.663 g=0.810
>epoch48, 937/937, d1=0.657, d2=0.632 g=0.780
>Accuracy real: 49%, fake: 71%
>epoch49, 937/937, d1=0.634, d2=0.671 g=0.757
>epoch50, 937/937, d1=0.607, d2=0.670 g=0.844

Updating FID with generated images: Batch389

FID: 38.83604049682617
>Accuracy real: 45%, fake: 75%
>epoch51, 937/937, d1=0.619, d2=0.666 g=0.881
>epoch52, 937/937, d1=0.564, d2=0.661 g=0.856
>Accuracy real: 51%, fake: 83%
>epoch53, 937/937, d1=0.569, d2=0.856 g=0.810
>epoch54, 937/937, d1=0.660, d2=0.641 g=0.885
>Accuracy real: 45%, fake: 83%
>epoch55, 937/937, d1=0.667, d2=0.679 g=0.827
>epoch56, 937/937, d1=0.778, d2=0.629 g=0.875
>Accuracy real: 50%, fake: 84%
>epoch57, 937/937, d1=0.745, d2=0.476 g=1.162
>epoch58, 937/937, d1=0.562, d2=0.713 g=0.901
>Accuracy real: 57%, fake: 83%
>epoch59, 937/937, d1=0.609, d2=0.640 g=0.919
>epoch60, 937/937, d1=0.541, d2=0.858 g=0.841

Updating FID with generated images: Batch389

FID: 40.037620544433594
>Accuracy real: 59%, fake: 81%
>epoch61, 937/937, d1=0.760, d2=0.595 g=0.917
>epoch62, 937/937, d1=0.691, d2=0.556 g=0.842
>Accuracy real: 47%, fake: 75%
>epoch63, 937/937, d1=0.639, d2=0.716 g=0.811
>epoch64, 937/937, d1=0.681, d2=0.732 g=0.839
>Accuracy real: 51%, fake: 88%
>epoch65, 937/937, d1=0.697, d2=0.629 g=0.909
>epoch66, 937/937, d1=0.630, d2=0.654 g=0.993
>Accuracy real: 48%, fake: 86%
>epoch67, 937/937, d1=0.731, d2=0.597 g=0.889
>epoch68, 937/937, d1=0.626, d2=0.576 g=0.914
>Accuracy real: 55%, fake: 73%
>epoch69, 937/937, d1=0.633, d2=0.550 g=0.846
>epoch70, 937/937, d1=0.681, d2=0.572 g=0.902

Updating FID with generated images: Batch389

FID: 40.42504119873047
>Accuracy real: 47%, fake: 87%
>epoch71, 937/937, d1=0.745, d2=0.598 g=0.947
>epoch72, 937/937, d1=0.697, d2=0.609 g=0.958
>Accuracy real: 46%, fake: 90%
>epoch73, 937/937, d1=0.692, d2=0.569 g=0.889
>epoch74, 937/937, d1=0.566, d2=0.664 g=0.929
>Accuracy real: 49%, fake: 87%
>epoch75, 937/937, d1=0.660, d2=0.585 g=0.964
>epoch76, 937/937, d1=0.629, d2=0.497 g=0.893
>Accuracy real: 51%, fake: 87%
>epoch77, 937/937, d1=0.778, d2=0.593 g=0.909
>epoch78, 937/937, d1=0.569, d2=0.642 g=0.878
>Accuracy real: 62%, fake: 73%
>epoch79, 937/937, d1=0.550, d2=0.722 g=0.916
>epoch80, 937/937, d1=0.585, d2=0.611 g=0.949

Updating FID with generated images: Batch389

FID: 42.541011810302734
>Accuracy real: 52%, fake: 85%
>epoch81, 937/937, d1=0.588, d2=0.566 g=0.925
>epoch82, 937/937, d1=0.743, d2=0.490 g=1.005
>Accuracy real: 55%, fake: 89%
>epoch83, 937/937, d1=0.656, d2=0.674 g=0.931
>epoch84, 937/937, d1=0.619, d2=0.720 g=0.938
>Accuracy real: 51%, fake: 85%
>epoch85, 937/937, d1=0.747, d2=0.581 g=0.955
>epoch86, 937/937, d1=0.531, d2=0.575 g=0.954
>Accuracy real: 55%, fake: 84%
>epoch87, 937/937, d1=0.690, d2=0.581 g=1.062
>epoch88, 937/937, d1=0.667, d2=0.673 g=0.926
>Accuracy real: 49%, fake: 71%
>epoch89, 937/937, d1=0.597, d2=0.685 g=1.033
>epoch90, 937/937, d1=0.607, d2=0.611 g=0.914

Updating FID with generated images: Batch389

FID: 43.484256744384766
>Accuracy real: 53%, fake: 82%
>epoch91, 937/937, d1=0.524, d2=0.521 g=0.912
>epoch92, 937/937, d1=0.655, d2=0.685 g=0.878
>Accuracy real: 45%, fake: 88%
>epoch93, 937/937, d1=0.573, d2=0.719 g=0.998
>epoch94, 937/937, d1=0.703, d2=0.575 g=1.007
>Accuracy real: 51%, fake: 85%
>epoch95, 937/937, d1=0.647, d2=0.533 g=0.861
>epoch96, 937/937, d1=0.646, d2=0.608 g=1.042
>Accuracy real: 46%, fake: 89%
>epoch97, 937/937, d1=0.614, d2=0.603 g=0.913
>epoch98, 937/937, d1=0.511, d2=0.616 g=1.176
>Accuracy real: 53%, fake: 84%
>epoch99, 937/937, d1=0.684, d2=0.546 g=0.983
>epoch100, 937/937, d1=0.639, d2=0.660 g=0.843

Updating FID with generated images: Batch389

FID: 43.99592590332031
>Accuracy real: 59%, fake: 79%
CPU times: user 1h 39min 31s, sys: 56min 26s, total: 2h 35min 58s
Wall time: 1h 37min

FID of ~43 is a great improvement from our baseline, let's see if batch normalisation will improve even further

Batch Normalisation

Only use in Generator model, otherwise model collapsed

In [98]:
# define the standalone discriminator model
def defineDisc(in_shape=(32, 32, 3)):
    model = tf.keras.models.Sequential()
    # normal
    model.add(tf.keras.layers.Conv2D(128, (3, 3), padding='same', input_shape=in_shape))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # classifier
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dropout(0.4))
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
    # compile model
    opt = tf.keras.optimizers.Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

# define the standalone generator model
def defineGen(latent_dim):
    model = tf.keras.models.Sequential()
    # foundation for 4x4 image
    n_nodes = 256 * 4 * 4
    model.add(tf.keras.layers.Dense(n_nodes, input_dim=latent_dim))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    model.add(tf.keras.layers.Reshape((4, 4, 256)))
    # upsample to 8x8
    model.add(tf.keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # upsample to 16x16
    model.add(tf.keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # upsample to 32x32
    model.add(tf.keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # output layer
    model.add(tf.keras.layers.Conv2D(3, (3, 3), activation='tanh', padding='same'))
    return model
In [99]:
# size of the latent space
latent_dim = 100
# create the discriminator
discModel = defineDisc()
# create the generator
genModel = defineGen(latent_dim)
# create the gan
gan_model = defineGAN(genModel, discModel)

print('***************DISCRIMINATOR**************')
discModel.summary()
print('***************GENERATOR**************')
genModel.summary()

gan_model.summary()
***************DISCRIMINATOR**************
Model: "sequential_75"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_15 (Conv2D)           (None, 32, 32, 128)       3584      
_________________________________________________________________
leaky_re_lu_24 (LeakyReLU)   (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 16, 16, 128)       147584    
_________________________________________________________________
leaky_re_lu_25 (LeakyReLU)   (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 8, 8, 64)          73792     
_________________________________________________________________
leaky_re_lu_26 (LeakyReLU)   (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 4, 4, 64)          36928     
_________________________________________________________________
leaky_re_lu_27 (LeakyReLU)   (None, 4, 4, 64)          0         
_________________________________________________________________
flatten_3 (Flatten)          (None, 1024)              0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_6 (Dense)              (None, 1)                 1025      
=================================================================
Total params: 262,913
Trainable params: 0
Non-trainable params: 262,913
_________________________________________________________________
***************GENERATOR**************
Model: "sequential_76"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_7 (Dense)              (None, 4096)              413696    
_________________________________________________________________
batch_normalization_12 (Batc (None, 4096)              16384     
_________________________________________________________________
leaky_re_lu_28 (LeakyReLU)   (None, 4096)              0         
_________________________________________________________________
reshape_3 (Reshape)          (None, 4, 4, 256)         0         
_________________________________________________________________
conv2d_transpose_9 (Conv2DTr (None, 8, 8, 64)          262208    
_________________________________________________________________
batch_normalization_13 (Batc (None, 8, 8, 64)          256       
_________________________________________________________________
leaky_re_lu_29 (LeakyReLU)   (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_transpose_10 (Conv2DT (None, 16, 16, 128)       131200    
_________________________________________________________________
batch_normalization_14 (Batc (None, 16, 16, 128)       512       
_________________________________________________________________
leaky_re_lu_30 (LeakyReLU)   (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_transpose_11 (Conv2DT (None, 32, 32, 128)       262272    
_________________________________________________________________
batch_normalization_15 (Batc (None, 32, 32, 128)       512       
_________________________________________________________________
leaky_re_lu_31 (LeakyReLU)   (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 32, 32, 3)         3459      
=================================================================
Total params: 1,090,499
Trainable params: 1,081,667
Non-trainable params: 8,832
_________________________________________________________________
Model: "sequential_77"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_76 (Sequential)   (None, 32, 32, 3)         1090499   
_________________________________________________________________
sequential_75 (Sequential)   (None, 1)                 262913    
=================================================================
Total params: 1,353,412
Trainable params: 1,081,667
Non-trainable params: 271,745
_________________________________________________________________
In [20]:
%%time
# load image data
x_train = loadRealSamples()
# train model
trainGAN(genModel, discModel, gan_model, x_train, latent_dim, n_epochs=100,n_batch=64,fid_frequency=10,name='model2BN')
>epoch1, 937/937, d1=0.319, d2=0.596 g=0.855
>epoch2, 937/937, d1=0.425, d2=0.509 g=1.139
>Accuracy real: 59%, fake: 97%
>epoch3, 937/937, d1=0.276, d2=0.168 g=2.034
>epoch4, 937/937, d1=0.722, d2=0.864 g=2.849
>Accuracy real: 68%, fake: 99%
>epoch5, 937/937, d1=0.155, d2=0.327 g=2.019
>epoch6, 937/937, d1=0.574, d2=0.225 g=3.113
>Accuracy real: 85%, fake: 100%
>epoch7, 937/937, d1=0.455, d2=0.162 g=2.907
>epoch8, 937/937, d1=0.950, d2=0.401 g=1.489
>Accuracy real: 71%, fake: 84%
>epoch9, 937/937, d1=0.471, d2=0.147 g=2.742
>epoch10, 937/937, d1=0.308, d2=0.283 g=1.742

Updating FID with generated images: Batch389

FID: 193.77963256835938
>Accuracy real: 87%, fake: 89%
>epoch11, 937/937, d1=0.460, d2=0.188 g=2.894
>epoch12, 937/937, d1=0.643, d2=0.321 g=1.493
>Accuracy real: 77%, fake: 91%
>epoch13, 937/937, d1=0.587, d2=0.462 g=1.334
>epoch14, 937/937, d1=0.536, d2=0.456 g=1.279
>Accuracy real: 65%, fake: 95%
>epoch15, 937/937, d1=0.638, d2=0.451 g=1.499
>epoch16, 937/937, d1=0.389, d2=0.378 g=1.481
>Accuracy real: 65%, fake: 95%
>epoch17, 937/937, d1=0.483, d2=0.490 g=1.013
>epoch18, 937/937, d1=0.472, d2=0.425 g=1.369
>Accuracy real: 66%, fake: 95%
>epoch19, 937/937, d1=0.598, d2=0.472 g=1.225
>epoch20, 937/937, d1=0.567, d2=0.617 g=1.083

Updating FID with generated images: Batch389

FID: 95.1962890625
>Accuracy real: 58%, fake: 83%
>epoch21, 937/937, d1=0.581, d2=0.518 g=1.021
>epoch22, 937/937, d1=0.691, d2=0.591 g=1.005
>Accuracy real: 59%, fake: 88%
>epoch23, 937/937, d1=0.648, d2=0.496 g=1.030
>epoch24, 937/937, d1=0.616, d2=0.571 g=0.981
>Accuracy real: 54%, fake: 82%
>epoch25, 937/937, d1=0.561, d2=0.542 g=1.004
>epoch26, 937/937, d1=0.633, d2=0.585 g=0.976
>Accuracy real: 55%, fake: 87%
>epoch27, 937/937, d1=0.581, d2=0.576 g=0.984
>epoch28, 937/937, d1=0.494, d2=0.703 g=0.997
>Accuracy real: 57%, fake: 87%
>epoch29, 937/937, d1=0.604, d2=0.608 g=1.035
>epoch30, 937/937, d1=0.624, d2=0.578 g=0.959

Updating FID with generated images: Batch389

FID: 57.23186111450195
>Accuracy real: 51%, fake: 90%
>epoch31, 937/937, d1=0.736, d2=0.588 g=0.952
>epoch32, 937/937, d1=0.730, d2=0.587 g=0.953
>Accuracy real: 57%, fake: 85%
>epoch33, 937/937, d1=0.705, d2=0.613 g=1.022
>epoch34, 937/937, d1=0.734, d2=0.562 g=0.943
>Accuracy real: 55%, fake: 87%
>epoch35, 937/937, d1=0.673, d2=0.449 g=1.015
>epoch36, 937/937, d1=0.596, d2=0.582 g=0.986
>Accuracy real: 55%, fake: 83%
>epoch37, 937/937, d1=0.651, d2=0.589 g=0.964
>epoch38, 937/937, d1=0.714, d2=0.600 g=1.006
>Accuracy real: 58%, fake: 86%
>epoch39, 937/937, d1=0.581, d2=0.518 g=0.910
>epoch40, 937/937, d1=0.595, d2=0.632 g=0.963

Updating FID with generated images: Batch389

FID: 54.51516342163086
>Accuracy real: 53%, fake: 83%
>epoch41, 937/937, d1=0.760, d2=0.670 g=0.946
>epoch42, 937/937, d1=0.659, d2=0.578 g=0.953
>Accuracy real: 51%, fake: 89%
>epoch43, 937/937, d1=0.681, d2=0.603 g=0.978
>epoch44, 937/937, d1=0.540, d2=0.655 g=0.984
>Accuracy real: 55%, fake: 86%
>epoch45, 937/937, d1=0.612, d2=0.623 g=0.961
>epoch46, 937/937, d1=0.594, d2=0.647 g=0.946
>Accuracy real: 59%, fake: 87%
>epoch47, 937/937, d1=0.583, d2=0.631 g=0.921
>epoch48, 937/937, d1=0.619, d2=0.633 g=0.935
>Accuracy real: 62%, fake: 88%
>epoch49, 937/937, d1=0.700, d2=0.582 g=0.981
>epoch50, 937/937, d1=0.617, d2=0.667 g=0.979

Updating FID with generated images: Batch389

FID: 46.495052337646484
>Accuracy real: 63%, fake: 87%
>epoch51, 937/937, d1=0.577, d2=0.616 g=0.915
>epoch52, 937/937, d1=0.701, d2=0.524 g=0.970
>Accuracy real: 60%, fake: 88%
>epoch53, 937/937, d1=0.635, d2=0.493 g=1.065
>epoch54, 937/937, d1=0.572, d2=0.598 g=0.949
>Accuracy real: 59%, fake: 92%
>epoch55, 937/937, d1=0.575, d2=0.669 g=1.019
>epoch56, 937/937, d1=0.646, d2=0.671 g=0.970
>Accuracy real: 58%, fake: 86%
>epoch57, 937/937, d1=0.639, d2=0.599 g=0.996
>epoch58, 937/937, d1=0.685, d2=0.689 g=1.021
>Accuracy real: 56%, fake: 87%
>epoch59, 937/937, d1=0.641, d2=0.620 g=0.916
>epoch60, 937/937, d1=0.485, d2=0.510 g=0.886

Updating FID with generated images: Batch389

FID: 46.95098114013672
>Accuracy real: 47%, fake: 86%
>epoch61, 937/937, d1=0.677, d2=0.604 g=0.936
>epoch62, 937/937, d1=0.589, d2=0.587 g=1.002
>Accuracy real: 51%, fake: 85%
>epoch63, 937/937, d1=0.620, d2=0.638 g=0.950
>epoch64, 937/937, d1=0.610, d2=0.776 g=0.914
>Accuracy real: 66%, fake: 79%
>epoch65, 937/937, d1=0.733, d2=0.599 g=1.054
>epoch66, 937/937, d1=0.583, d2=0.547 g=0.970
>Accuracy real: 51%, fake: 85%
>epoch67, 937/937, d1=0.623, d2=0.540 g=0.923
>epoch68, 937/937, d1=0.559, d2=0.778 g=0.990
>Accuracy real: 43%, fake: 95%
>epoch69, 937/937, d1=0.639, d2=0.594 g=0.931
>epoch70, 937/937, d1=0.690, d2=0.672 g=0.906

Updating FID with generated images: Batch389

FID: 44.9049072265625
>Accuracy real: 62%, fake: 81%
>epoch71, 937/937, d1=0.651, d2=0.589 g=0.992
>epoch72, 937/937, d1=0.567, d2=0.630 g=1.025
>Accuracy real: 48%, fake: 86%
>epoch73, 937/937, d1=0.752, d2=0.717 g=0.905
>epoch74, 937/937, d1=0.631, d2=0.522 g=0.979
>Accuracy real: 53%, fake: 85%
>epoch75, 937/937, d1=0.629, d2=0.588 g=0.903
>epoch76, 937/937, d1=0.707, d2=0.565 g=0.904
>Accuracy real: 47%, fake: 88%
>epoch77, 937/937, d1=0.589, d2=0.610 g=1.040
>epoch78, 937/937, d1=0.744, d2=0.549 g=0.924
>Accuracy real: 47%, fake: 93%
>epoch79, 937/937, d1=0.644, d2=0.603 g=0.858
>epoch80, 937/937, d1=0.694, d2=0.701 g=1.033

Updating FID with generated images: Batch389

FID: 42.80244445800781
>Accuracy real: 53%, fake: 91%
>epoch81, 937/937, d1=0.586, d2=0.653 g=0.954
>epoch82, 937/937, d1=0.673, d2=0.568 g=1.004
>Accuracy real: 52%, fake: 89%
>epoch83, 937/937, d1=0.769, d2=0.494 g=0.976
>epoch84, 937/937, d1=0.736, d2=0.534 g=0.955
>Accuracy real: 51%, fake: 84%
>epoch85, 937/937, d1=0.679, d2=0.702 g=1.023
>epoch86, 937/937, d1=0.634, d2=0.704 g=1.060
>Accuracy real: 45%, fake: 90%
>epoch87, 937/937, d1=0.609, d2=0.616 g=1.101
>epoch88, 937/937, d1=0.677, d2=0.577 g=1.060
>Accuracy real: 57%, fake: 92%
>epoch89, 937/937, d1=0.785, d2=0.637 g=0.938
>epoch90, 937/937, d1=0.672, d2=0.800 g=0.928

Updating FID with generated images: Batch389

FID: 44.24906539916992
>Accuracy real: 63%, fake: 88%
>epoch91, 937/937, d1=0.769, d2=0.625 g=1.060
>epoch92, 937/937, d1=0.483, d2=0.783 g=1.016
>Accuracy real: 57%, fake: 86%
>epoch93, 937/937, d1=0.693, d2=0.631 g=0.980
>epoch94, 937/937, d1=0.589, d2=0.626 g=1.056
>Accuracy real: 62%, fake: 91%
>epoch95, 937/937, d1=0.696, d2=0.636 g=1.051
>epoch96, 937/937, d1=0.597, d2=0.591 g=1.047
>Accuracy real: 64%, fake: 91%
>epoch97, 937/937, d1=0.581, d2=0.505 g=0.953
>epoch98, 937/937, d1=0.560, d2=0.673 g=1.038
>Accuracy real: 56%, fake: 85%
>epoch99, 937/937, d1=0.733, d2=0.522 g=0.935
>epoch100, 937/937, d1=0.585, d2=0.580 g=0.907

Updating FID with generated images: Batch389

FID: 39.03178405761719
>Accuracy real: 52%, fake: 90%
CPU times: user 1h 17min 54s, sys: 45min 28s, total: 2h 3min 23s
Wall time: 1h 21min 21s

Summary

  • No Batch Norm FID: 43
  • With Batch Norm FID: 39 Therefore, we will use Batch Normalisation in our generator models from now on. Another thing is that we saw the FIDs actually increase near the end of training. This could be due to overfitting so we will try to correct that in the next model so we can get more stable training for longer duration

Model 3

  • Image Data Generator
  • Label smoothing
  • Improve training loop Fixed seed evaluation Learning rate scheduler Loss storing Gaussian weight initialisation ** Save best FID model

Image Data Generator

We will do light augmentation because too much augmentation can cause the generator to learn the distribution of augmented data instead of real data

In [13]:
def createImageGen():
  datagen = tf.keras.preprocessing.image.ImageDataGenerator(
      rotation_range=15,
      width_shift_range=0.15,
      height_shift_range=0.15,
      zoom_range=0.2,
      horizontal_flip=True
      )
  return datagen

#function to reverse scaling/preprocessing steps, for data visualisation purposes
def reverseProcess(X):
    X = (X *127.5)+127.5
    X = X.astype('uint8')
    return X
    
In [10]:
#visualise some examples
datagen=createImageGen()
x_train=loadRealSamples()
images=x_train[0:5]
augImages= datagen.flow(images, shuffle=False).next()
fig=plt.figure(figsize=(8,8))
fig.subplots_adjust(wspace=0.1)
count=1
for img,augImg in zip(reverseProcess(images),reverseProcess(augImages)):
    plt.subplot(5,2,count)
    plt.axis('off')
    plt.imshow(img)
    count+=1
    
    plt.subplot(5,2,count)
    plt.axis('off')
    plt.imshow(augImg )
    count+=1
plt.suptitle('Original                                     Augmented',fontsize=18)
plt.show()

Fixed seed evaluation

Now when the code displays the generated images the seed of the latent points will be the same. This way we can more easily see the progress of the model as it is trying to generate the same image everytime (doesnt affect actual training of model only visualisation)

In [11]:
import random #use python random library for fixed seeds

def generateFixedLatents(latentDim, n_samples):
    random.seed(0) #fixed seed
    # generate points in the latent space
    x_input = [random.gauss(0, 1) for i in range(latentDim * n_samples)]
    # reshape into a batch of inputs for the network
    x_input = np.array(x_input).reshape(n_samples, latentDim)
    return x_input


# use the generator to generate n fake examples, with class labels
def generateFixedSamples(genModel, latentDim, n_samples):
    # generate points in latent space
    x_input = generateFixedLatents(latentDim, n_samples)
    # predict outputs
    X = genModel.predict(x_input)
    # create 'fake' class labels (0)
    y = np.zeros((n_samples, 1))
    return X, y


# evaluate the discriminator, plot generated images, save generator model
def summarisePerf2(epoch, genModel, discModel, dataset, latentDim, name='model', n_samples=150):
    # prepare real samples
    X_real, y_real = selectRealSamples(dataset, n_samples)
    # evaluate discriminator on real examples
    _, acc_real = discModel.evaluate(X_real, y_real, verbose=0)
    # prepare fake examples
    x_fake, y_fake = generateFixedSamples(genModel, latentDim, n_samples)
    # evaluate discriminator on fake examples
    _, acc_fake = discModel.evaluate(x_fake, y_fake, verbose=0)
    # summarize discriminator performance
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real * 100, acc_fake * 100))
    # save plot
    savePlot(x_fake, epoch,name,save_frequency=1)

One-sided Label smoothing

Change positive labels (1) to a smooth range, this is known as soft label smoothing. This approach helps to improve the model's generalization ability by making it more robust to noise in the target labels.

In [12]:
def smooth_positive_labels(y):
 return y - 0.1 + (np.random.random(y.shape) * 0.1) #range of 0.9 to 1

Gaussian Weight Initialisation

Gaussian weight initialization is used to initialize the weights of layers with random values from a normal (Gaussian) distribution. This helps to prevent the model from getting stuck in poor local minima, leading to better convergence during training

In [9]:
def gaussian_weight_init(shape, dtype=None):
  mean = 0.0
  std = 0.1
  return K.random_normal(shape, mean=mean, stddev=std, dtype=dtype)

Model architecture left unchanged, but use gaussian weights now

In [14]:
# define the standalone discriminator model
def defineDisc(in_shape=(32, 32, 3)):
    model = tf.keras.models.Sequential()
    # normal
    model.add(tf.keras.layers.Conv2D(128, (3, 3), padding='same', input_shape=in_shape,kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # classifier
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dropout(0.4))
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
    # compile model
    opt = tf.keras.optimizers.Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

# define the standalone generator model
def defineGen(latent_dim):
    model = tf.keras.models.Sequential()
    # foundation for 4x4 image
    n_nodes = 256 * 4 * 4
    model.add(tf.keras.layers.Dense(n_nodes, input_dim=latent_dim))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    model.add(tf.keras.layers.Reshape((4, 4, 256)))
    # upsample to 8x8
    model.add(tf.keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # upsample to 16x16
    model.add(tf.keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # upsample to 32x32
    model.add(tf.keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # output layer
    model.add(tf.keras.layers.Conv2D(3, (3, 3), activation='tanh', padding='same'))
    return model

Storing losses for plotting later

In [16]:
#Store losses in list
d1Losses=[]
d2Losses=[]
gLosses=[]

Learning rate scheduler

During training we can change the learning rate to become lower based on the epoch count. This allows for more fine grained learning nearer to the end. Like taking smaller steps down the mountain the lower we go during gradient descent.

In [19]:
#Learning rate schedulers
def schedulerDisc(epoch):
    if epoch < 30:
        return 0.00015
    elif epoch < 60:
        return 0.00012
    elif epoch < 120:
        return 0.0001
    else:
        return 0.00007

def schedulerGen(epoch):
    if epoch < 30:
        return 0.0002
    elif epoch < 60:
        return 0.00018
    elif epoch < 120:
        return 0.00015
    else:
        return 0.0001
In [23]:
# define the combined generator and discriminator model, for updating the generator
def defineGAN(genModel, discModel):
    # make weights in the discriminator not trainable
    discModel.trainable = False
    # connect them
    model = tf.keras.models.Sequential()
    # add generator
    model.add(genModel)
    # add the discriminator
    model.add(discModel)
    # compile model
    opt = tf.keras.optimizers.Adam(lr=0.0002,beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

# train the generator and discriminator
def trainGAN(genModel, discModel, gan_model, dataset,latentDim, n_epochs=200, n_batch=128,fid_frequency=50,sumFrequency=2,name='model'):
    bat_per_epo = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    bestFID=200 #just leave as a high number first
    iter=datagen.flow(dataset,y=None,batch_size=half_batch) #half batch size
    # manually enumerate epochs
    for i in range(n_epochs):
        
        #set learning rates via schedulers
        K.set_value(discModel.optimizer.learning_rate,schedulerDisc(i))
        K.set_value(gan_model.optimizer.learning_rate,schedulerGen(i))
        
        
        
        # enumerate batches over the training set
        for j in range(bat_per_epo):
            # get randomly selected augmented samples
            X_real=iter.next()
            y_real = np.ones((len(X_real), 1)) #get 'real' labels
            #smoothing labels
            y_real=smooth_positive_labels(y_real)
            # update discriminator model weights
            d_loss1, _ = discModel.train_on_batch(X_real, y_real)
            # generate 'fake' examples
            X_fake, y_fake = generateFakeSamples(genModel, latentDim, half_batch)
            # update discriminator model weights
            d_loss2, _ = discModel.train_on_batch(X_fake, y_fake)
            # prepare points in latent space as input for the generator
            X_gan = generateLatentPoints(latentDim, n_batch)
            # create inverted labels for the fake samples
            y_gan = np.ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            
        
            #progress bar per epoch
            if j%25==0 or j==bat_per_epo-1:
                
                dLR_sf = round(discModel.optimizer.learning_rate.numpy() / 10 ** math.floor(math.log10(discModel.optimizer.learning_rate.numpy())),1) * 10 ** math.floor(math.log10(discModel.optimizer.learning_rate.numpy()))
                gLR_sf = round(gan_model.optimizer.learning_rate.numpy() / 10 ** math.floor(math.log10(gan_model.optimizer.learning_rate.numpy())),1) * 10 ** math.floor(math.log10(gan_model.optimizer.learning_rate.numpy()))
                if j==bat_per_epo-1: #if last step
                    #updating losses lists
                    d1Losses.append(d_loss1)
                    d2Losses.append(d_loss2)
                    gLosses.append(g_loss)
                    print(f'>epoch{i+1}, {j+1}/{bat_per_epo}, dLR:{dLR_sf}, gLR:{gLR_sf}, d1={d_loss1:.3f}, d2={d_loss2:.3f}, g={g_loss:.3f}')
                else:
                    print(f'>epoch{i+1}, {j+1}/{bat_per_epo}, dLR:{dLR_sf}, gLR:{gLR_sf}, d1={d_loss1:.3f}, d2={d_loss2:.3f}, g={g_loss:.3f}', 
                          end='\r')
        #compute FID for every interval and last epoch
        if (i>0 and (i+1)%fid_frequency==0) or (i+1)==n_epochs:
            updated_gen_model = tf.keras.models.Sequential()
            updated_gen_model.add(gan_model.layers[0])
            # Set the weights of the new model to be the same as the generator portion of the gan_model
            updated_gen_model.set_weights(gan_model.layers[0].get_weights())
            FID=computeFID(updated_gen_model,latentDim)
            if FID<bestFID:
                print('saving best model')
                bestFID=FID
                #save model
                updated_gen_model = tf.keras.models.Sequential()
                updated_gen_model.add(gan_model.layers[0])
                # Set the weights of the new model to be the same as the generator portion of the gan_model
                updated_gen_model.set_weights(gan_model.layers[0].get_weights())
                updated_gen_model.save('models/model3.h5')
                
        # evaluate the model performance every few epochs
        summarisePerf2(i, genModel, discModel, dataset, latentDim)
    #return losses lists
    return [d1Losses,d2Losses,gLosses]
In [24]:
# size of the latent space
latent_dim = 128
# create the discriminator
discModel = defineDisc()
# create the generator
genModel = defineGen(latent_dim)
# create the gan
gan_model = defineGAN(genModel, discModel)

print('***************DISCRIMINATOR**************')
discModel.summary()
print('***************GENERATOR**************')
genModel.summary()

gan_model.summary()
***************DISCRIMINATOR**************
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_5 (Conv2D)            (None, 32, 32, 128)       3584      
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 16, 16, 128)       147584    
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 8, 8, 128)         147584    
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU)   (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 4, 4, 128)         147584    
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU)   (None, 4, 4, 128)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 2048)              0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 2048)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 448,385
Trainable params: 0
Non-trainable params: 448,385
_________________________________________________________________
***************GENERATOR**************
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 4096)              528384    
_________________________________________________________________
batch_normalization_4 (Batch (None, 4096)              16384     
_________________________________________________________________
leaky_re_lu_12 (LeakyReLU)   (None, 4096)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 4, 4, 256)         0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 8, 8, 128)         524416    
_________________________________________________________________
batch_normalization_5 (Batch (None, 8, 8, 128)         512       
_________________________________________________________________
leaky_re_lu_13 (LeakyReLU)   (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 16, 16, 128)       262272    
_________________________________________________________________
batch_normalization_6 (Batch (None, 16, 16, 128)       512       
_________________________________________________________________
leaky_re_lu_14 (LeakyReLU)   (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 32, 32, 64)        131136    
_________________________________________________________________
batch_normalization_7 (Batch (None, 32, 32, 64)        256       
_________________________________________________________________
leaky_re_lu_15 (LeakyReLU)   (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 32, 32, 3)         1731      
=================================================================
Total params: 1,465,603
Trainable params: 1,456,771
Non-trainable params: 8,832
_________________________________________________________________
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_5 (Sequential)    (None, 32, 32, 3)         1465603   
_________________________________________________________________
sequential_4 (Sequential)    (None, 1)                 448385    
=================================================================
Total params: 1,913,988
Trainable params: 1,456,771
Non-trainable params: 457,217
_________________________________________________________________
In [25]:
%%time
# load image data
x_train = loadRealSamples()
# train model
d1Loss,d2Loss,gLoss=trainGAN(genModel, discModel, gan_model, x_train,
            latent_dim, n_epochs=200,n_batch=64,fid_frequency=5,sumFrequency=1,name='model3')
>epoch1, 937/937, dLR:0.00015000000000000001, gLR:0.0002, d1=0.604, d2=0.101, g=2.3797
>Accuracy real: 95%, fake: 100%
>epoch198, 937/937, dLR:7.000000000000001e-05, gLR:0.0001, d1=0.570, d2=0.253, g=1.979
>Accuracy real: 69%, fake: 95%
>epoch199, 937/937, dLR:7.000000000000001e-05, gLR:0.0001, d1=0.621, d2=0.327, g=1.781
>Accuracy real: 75%, fake: 98%
>epoch200, 937/937, dLR:7.000000000000001e-05, gLR:0.0001, d1=0.497, d2=0.307, g=1.744

Updating FID with generated images: Batch77

FID: 47.8980598449707
>Accuracy real: 73%, fake: 97%
CPU times: user 4h 7min 8s, sys: 4h 1min 7s, total: 8h 8min 15s
Wall time: 3h 15min 21s

Plot training

In [26]:
epochs=range(len(d1Loss))
plt.figure(figsize=(15,10))
plt.plot(epochs,d1Loss)
plt.plot(epochs,d2Loss)
plt.plot(epochs,gLoss)
plt.title('Model 3 Training')
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.legend(['d1','d2','g'],loc='upper left')
plt.show()

Test best model on Full Training Set for FID

Previously in training we were calculating FID on the 10K test set of images. This is quite efficient as it is fast to compute but it is not the most accurate way to calculate FID. We shall now calculate on the full set of 50K+10K=60K images since it is what we trained on too.

In [15]:
#load
bestModel=tf.keras.models.load_model('models/model3.h5',custom_objects={'gaussian_weight_init': gaussian_weight_init})
bestModel.summary()
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_93"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_64 (Sequential)   (None, 32, 32, 3)         2971523   
=================================================================
Total params: 2,971,523
Trainable params: 2,966,531
Non-trainable params: 4,992
_________________________________________________________________
In [23]:
from torchmetrics.image.fid import FrechetInceptionDistance
import torch
fid=FrechetInceptionDistance(reset_real_features=False).to('cuda')
fid.reset()
real_images = loadRealSamples()
real_images = real_images.astype(np.float32)
real_images=torch.from_numpy(np.array(real_images))

#change dtypes, scale back first
real_images=(real_images+1)*(255/2)
real_images=real_images.type(torch.uint8)
#from 50,000x32x32x3 to 50,000x3x32x32
real_images = real_images.permute(0, 3, 1, 2)
real_images = real_images.to('cuda')
#COMPUTE FID

# create a dataloader with batch size 128
dataloader = torch.utils.data.DataLoader(real_images, batch_size=128, shuffle=True,
                                          drop_last=True
                                          )


# update the FID with real images
for batch,images in enumerate(dataloader):
    print(f'Updating with real images now...Batch{batch}',end='\r')
    fid.update(images, real=True)

print(f'\nFinished! Ready to take in generated image features')

def computeFID(generator,latentDim):
    fid.reset()
    #PREPARE DATA
    gen_images = generateFakeSamples(generator,latentDim,60000)[0]
#     gen_images=np.reshape(gen_images,(50000,3,32,32))
    gen_images = gen_images.astype(np.float32)
    gen_images = torch.from_numpy(np.array(gen_images))
    gen_images=(gen_images+1)*(255/2)
    gen_images=gen_images.type(torch.uint8)
    gen_images = gen_images.permute(0, 3, 1, 2)
    gen_images = gen_images.to('cuda')

    # create a dataloader with batch size 128
    dataloader_gen = torch.utils.data.DataLoader(gen_images, batch_size=128, shuffle=True,
                                            # pin_memory=True,
                                             drop_last=True)
    
    print('')
    # update the FID with generated images
    for batch,images in enumerate(dataloader_gen):
        print(f'Updating FID with generated images: Batch{batch}',end='\r')
        fid.update(images, real=False)
    
    print('\n')
    score = fid.compute()
    print(f'FID: {score.item()}')
    return score
Updating with real images now...Batch467
Finished! Ready to take in generated image features
In [22]:
FID=computeFID(bestModel,100)
Updating FID with generated images: Batch 467

FID: 32.12918853759766

Summary

Our FID had a nice improvement to 32. From the loss graph we can see that near the end of training generator loss was slightly increasing but FID was still remaining stable/improving. So learning was still taking place at that time.

Model 4: Conditional DCGAN (cDCGAN)

Our DCGAN gave good performance but in a real life scenario being able to generate a random image from such a broad dataset like CIFAR10 is not very useful. If we can specify which class we want to generate like a car or a plane, that will become very useful. Conditional DCGAN is a variant of DCGANs that generates new data samples based on a given condition, such as in this case, class labels. They are more challenging to train because the generator has to take into account the additional condition and generate samples that conform to it. Additionally, the discriminator has to be able to distinguish between real and fake samples with respect to the condition, which increases its complexity and the difficulty of training. However, the benefits of being able to generate class-conditioned samples often outweigh the additional challenges.

In [25]:
# !pip install -q git+https://github.com/tensorflow/docs
# !pip install imageio
Requirement already satisfied: imageio in /opt/conda/lib/python3.8/site-packages (2.9.0)
Requirement already satisfied: pillow in /opt/conda/lib/python3.8/site-packages (from imageio) (7.2.0)
Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (from imageio) (1.19.1)
In [24]:
from tensorflow.keras import layers
from tensorflow_docs.vis import embed
import imageio
In [12]:
from tensorflow import keras
In [15]:
batch_size = 64
num_channels = 3
num_classes = 10
image_size = 32
latent_dim = 128
In [16]:
all_digits=loadRealSamples()
all_labels=loadRealLabels()
print(all_digits.shape)
print(all_labels.shape)
(60000, 32, 32, 3)
(60000, 10)
In [17]:
# Create tf.data.Dataset.
dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
In [18]:
generator_in_channels = latent_dim + num_classes
discriminator_in_channels = num_channels + num_classes
print(generator_in_channels, discriminator_in_channels)
138 13
In [19]:
# Create the discriminator.
discriminator = tf.keras.Sequential(
    [
        layers.InputLayer((32, 32,discriminator_in_channels)),
        layers.Conv2D(64, (5,5),strides=2,  padding="same"),
        layers.LeakyReLU(alpha=0.1),
        layers.Conv2D(128, (5,5), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.1),
        layers.Conv2D(128, (5,5), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.1),
        layers.Conv2D(256, (5,5), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.1),
        layers.Flatten(),
        layers.Dense(256,activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(1)
    ],
    name="discriminator",
)

# Create the generator.
generator = keras.Sequential(
    [
        layers.InputLayer((generator_in_channels,)),
        # We want to generate 128 + num_classes coefficients to reshape into a
        # 7x7x(128 + num_classes) map.
        layers.Dense(4 * 4 * 256),
#         layers.BatchNormalization(momentum=0.9),
        layers.LeakyReLU(alpha=0.1),
        layers.Reshape((4,4, 256)),

        layers.Conv2DTranspose(256, (5,5), strides=(2, 2), padding="same"),
#         layers.BatchNormalization(momentum=0.9),
        layers.LeakyReLU(alpha=0.1),
        
        layers.Conv2DTranspose(128, (5,5), strides=(2, 2), padding="same"),
#         layers.BatchNormalization(momentum=0.9),
        layers.LeakyReLU(alpha=0.1),

        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding="same"),
#         layers.BatchNormalization(momentum=0.9),
        layers.LeakyReLU(alpha=0.1),

        layers.Conv2D(3, (4,4), padding="same", activation="tanh"),
    ],
    name="generator",
)
discriminator.summary()
generator.summary()
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 16, 16, 64)        20864     
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 8, 8, 128)         204928    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 4, 4, 128)         409728    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 4, 4, 128)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 2, 2, 256)         819456    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 2, 2, 256)         0         
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 256)               262400    
_________________________________________________________________
dropout (Dropout)            (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 257       
=================================================================
Total params: 1,717,633
Trainable params: 1,717,633
Non-trainable params: 0
_________________________________________________________________
Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_2 (Dense)              (None, 4096)              569344    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 4096)              0         
_________________________________________________________________
reshape (Reshape)            (None, 4, 4, 256)         0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 8, 8, 256)         1638656   
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 8, 8, 256)         0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 16, 16, 128)       819328    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 32, 32, 64)        204864    
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 32, 32, 3)         3075      
=================================================================
Total params: 3,235,267
Trainable params: 3,235,267
Non-trainable params: 0
_________________________________________________________________
In [13]:
#Conditional GAN class, inherit from keras Model class
class ConditionalGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        # Unpack the data.
        real_images, one_hot_labels = data

        # Add dummy dimensions to the labels so that they can be concatenated with
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = tf.repeat(
            image_one_hot_labels, repeats=[image_size * image_size]
        )
        image_one_hot_labels = tf.reshape(
            image_one_hot_labels, (-1, image_size, image_size, num_classes)
        )

        # Sample random points in the latent space and concatenate labels
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Decode the noise (guided by labels) to fake images.
        generated_images = self.generator(random_vector_labels)

        # Combine them with real images
        fake_image_and_labels = tf.concat([generated_images, image_one_hot_labels], -1)
        real_image_and_labels = tf.concat([real_images, image_one_hot_labels], -1)
        combined_images = tf.concat(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )

        # Assemble labels discriminating real from fake images.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space.
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Assemble labels that say "all real images".
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator 
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = tf.concat([fake_images, image_one_hot_labels], -1)
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }
In [62]:
cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0002),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0002),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

cond_gan.fit(dataset, epochs=200)
Epoch 1/200
938/938 [==============================] - 21s 19ms/step - g_loss: 8.4527 - d_loss: 0.2099
Epoch 2/200
938/938 [==============================] - 18s 19ms/step - g_loss: 3.0643 - d_loss: 0.2604
Epoch 3/200
938/938 [==============================] - 18s 19ms/step - g_loss: 3.2579 - d_loss: 0.2389
Epoch 4/200
938/938 [==============================] - 18s 19ms/step - g_loss: 3.2516 - d_loss: 0.2981
Epoch 5/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.8838 - d_loss: 0.3003
Epoch 6/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.6836 - d_loss: 0.3150
Epoch 7/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.6404 - d_loss: 0.3679
Epoch 8/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.2574 - d_loss: 0.4137
Epoch 9/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.2382 - d_loss: 0.3979
Epoch 10/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.3395 - d_loss: 0.4132
Epoch 11/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.6866 - d_loss: 0.4867
Epoch 12/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.8253 - d_loss: 0.4781
Epoch 13/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.6806 - d_loss: 0.4686
Epoch 14/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.4501 - d_loss: 0.5367
Epoch 15/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.5582 - d_loss: 0.4866
Epoch 16/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.5868 - d_loss: 0.4846
Epoch 17/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.6529 - d_loss: 0.5188
Epoch 18/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.5661 - d_loss: 0.5191
Epoch 19/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.4432 - d_loss: 0.5296
Epoch 20/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3418 - d_loss: 0.5575
Epoch 21/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.4006 - d_loss: 0.5336
Epoch 22/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3024 - d_loss: 0.5459
Epoch 23/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3596 - d_loss: 0.5358
Epoch 24/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3528 - d_loss: 0.5381
Epoch 25/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3929 - d_loss: 0.5581
Epoch 26/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2055 - d_loss: 0.5554
Epoch 27/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2325 - d_loss: 0.5620
Epoch 28/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2481 - d_loss: 0.5695
Epoch 29/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3550 - d_loss: 0.5551
Epoch 30/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.4210 - d_loss: 0.5452
Epoch 31/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2895 - d_loss: 0.5550
Epoch 32/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3359 - d_loss: 0.5407
Epoch 33/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3889 - d_loss: 0.5335
Epoch 34/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2359 - d_loss: 0.5446
Epoch 35/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3649 - d_loss: 0.5518
Epoch 36/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2065 - d_loss: 0.5661
Epoch 37/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3573 - d_loss: 0.5717
Epoch 38/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2556 - d_loss: 0.5644
Epoch 39/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2546 - d_loss: 0.5532
Epoch 40/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1211 - d_loss: 0.5834
Epoch 41/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2240 - d_loss: 0.5755
Epoch 42/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2033 - d_loss: 0.5863
Epoch 43/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1615 - d_loss: 0.5751
Epoch 44/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1775 - d_loss: 0.5761
Epoch 45/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2063 - d_loss: 0.5840
Epoch 46/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2181 - d_loss: 0.5908
Epoch 47/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1410 - d_loss: 0.5803
Epoch 48/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1677 - d_loss: 0.5720
Epoch 49/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1744 - d_loss: 0.5852
Epoch 50/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1251 - d_loss: 0.5745
Epoch 51/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.4512 - d_loss: 0.5797
Epoch 52/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1498 - d_loss: 0.5731
Epoch 53/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0951 - d_loss: 0.5921
Epoch 54/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1488 - d_loss: 0.5880
Epoch 55/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2283 - d_loss: 0.5698
Epoch 56/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1628 - d_loss: 0.5812
Epoch 57/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1452 - d_loss: 0.5848
Epoch 58/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1566 - d_loss: 0.5728
Epoch 59/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1012 - d_loss: 0.5933
Epoch 60/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2044 - d_loss: 0.5997
Epoch 61/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0574 - d_loss: 0.6163
Epoch 62/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0956 - d_loss: 0.5890
Epoch 63/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1639 - d_loss: 0.5790
Epoch 64/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1255 - d_loss: 0.5896
Epoch 65/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1108 - d_loss: 0.5982
Epoch 66/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1028 - d_loss: 0.5859
Epoch 67/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1286 - d_loss: 0.5952
Epoch 68/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0770 - d_loss: 0.5779
Epoch 69/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1421 - d_loss: 0.5998
Epoch 70/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0469 - d_loss: 0.5955
Epoch 71/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0912 - d_loss: 0.6036
Epoch 72/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0100 - d_loss: 0.6158
Epoch 73/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0606 - d_loss: 0.5983
Epoch 74/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0476 - d_loss: 0.6036
Epoch 75/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0230 - d_loss: 0.5972
Epoch 76/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0259 - d_loss: 0.6232
Epoch 77/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1323 - d_loss: 0.5829
Epoch 78/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0102 - d_loss: 0.6074
Epoch 79/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1344 - d_loss: 0.5810
Epoch 80/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0546 - d_loss: 0.6029
Epoch 81/200
938/938 [==============================] - 18s 19ms/step - g_loss: 0.9659 - d_loss: 0.6193
Epoch 82/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0435 - d_loss: 0.6028
Epoch 83/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0365 - d_loss: 0.6147
Epoch 84/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0733 - d_loss: 0.5933
Epoch 85/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1464 - d_loss: 0.6052
Epoch 86/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0592 - d_loss: 0.5804
Epoch 87/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0518 - d_loss: 0.6028
Epoch 88/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1094 - d_loss: 0.5866
Epoch 89/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0128 - d_loss: 0.6049
Epoch 90/200
938/938 [==============================] - 18s 19ms/step - g_loss: 0.9947 - d_loss: 0.6133
Epoch 91/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0543 - d_loss: 0.6068
Epoch 92/200
938/938 [==============================] - 18s 19ms/step - g_loss: 0.9994 - d_loss: 0.6001
Epoch 93/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0683 - d_loss: 0.6059
Epoch 94/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0126 - d_loss: 0.5957
Epoch 95/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0235 - d_loss: 0.5987
Epoch 96/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0531 - d_loss: 0.5956
Epoch 97/200
938/938 [==============================] - 18s 19ms/step - g_loss: 0.9993 - d_loss: 0.6098
Epoch 98/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0160 - d_loss: 0.6041
Epoch 99/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0152 - d_loss: 0.5939
Epoch 100/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0615 - d_loss: 0.5902
Epoch 101/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0134 - d_loss: 0.5900
Epoch 102/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0652 - d_loss: 0.5953
Epoch 103/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0489 - d_loss: 0.5883
Epoch 104/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0562 - d_loss: 0.5820
Epoch 105/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0810 - d_loss: 0.5839
Epoch 106/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0364 - d_loss: 0.5982
Epoch 107/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0784 - d_loss: 0.5884
Epoch 108/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0664 - d_loss: 0.5813
Epoch 109/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0293 - d_loss: 0.5897
Epoch 110/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.0799 - d_loss: 0.5728
Epoch 111/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1001 - d_loss: 0.5671
Epoch 112/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1217 - d_loss: 0.5646
Epoch 113/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1348 - d_loss: 0.5606
Epoch 114/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.1564 - d_loss: 0.5488
Epoch 115/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2096 - d_loss: 0.5386
Epoch 116/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2314 - d_loss: 0.5230
Epoch 117/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.2743 - d_loss: 0.5237
Epoch 118/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3255 - d_loss: 0.5088
Epoch 119/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.3732 - d_loss: 0.4954
Epoch 120/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.4225 - d_loss: 0.4812
Epoch 121/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.5030 - d_loss: 0.4658
Epoch 122/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.5772 - d_loss: 0.4463
Epoch 123/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.6727 - d_loss: 0.4282
Epoch 124/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.7592 - d_loss: 0.4111
Epoch 125/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.8565 - d_loss: 0.3965
Epoch 126/200
938/938 [==============================] - 18s 19ms/step - g_loss: 1.9907 - d_loss: 0.3737
Epoch 127/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.1591 - d_loss: 0.3533
Epoch 128/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.2874 - d_loss: 0.3402
Epoch 129/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.4132 - d_loss: 0.3245
Epoch 130/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.5795 - d_loss: 0.3084
Epoch 131/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.7668 - d_loss: 0.2870
Epoch 132/200
938/938 [==============================] - 18s 19ms/step - g_loss: 2.9065 - d_loss: 0.2730
Epoch 133/200
938/938 [==============================] - 18s 19ms/step - g_loss: 3.0858 - d_loss: 0.2593
Epoch 134/200
938/938 [==============================] - 18s 19ms/step - g_loss: 3.2672 - d_loss: 0.2501
Epoch 135/200
938/938 [==============================] - 18s 19ms/step - g_loss: 3.3826 - d_loss: 0.2397
Epoch 136/200
938/938 [==============================] - 18s 19ms/step - g_loss: 3.4968 - d_loss: 0.2328
Epoch 137/200
938/938 [==============================] - 18s 19ms/step - g_loss: 3.7631 - d_loss: 0.2205
Epoch 138/200
938/938 [==============================] - 18s 19ms/step - g_loss: 3.8385 - d_loss: 0.2127
Epoch 139/200
938/938 [==============================] - 18s 19ms/step - g_loss: 4.0076 - d_loss: 0.1994
Epoch 140/200
938/938 [==============================] - 17s 19ms/step - g_loss: 4.0933 - d_loss: 0.2037
Epoch 141/200
938/938 [==============================] - 18s 19ms/step - g_loss: 4.2088 - d_loss: 0.1938
Epoch 142/200
938/938 [==============================] - 18s 19ms/step - g_loss: 4.3336 - d_loss: 0.1863
Epoch 143/200
938/938 [==============================] - 18s 19ms/step - g_loss: 4.4730 - d_loss: 0.1817
Epoch 144/200
938/938 [==============================] - 18s 19ms/step - g_loss: 4.6006 - d_loss: 0.1720
Epoch 145/200
938/938 [==============================] - 18s 19ms/step - g_loss: 4.6488 - d_loss: 0.1699
Epoch 146/200
938/938 [==============================] - 18s 19ms/step - g_loss: 4.8065 - d_loss: 0.1693
Epoch 147/200
938/938 [==============================] - 18s 19ms/step - g_loss: 4.8686 - d_loss: 0.1664
Epoch 148/200
938/938 [==============================] - 18s 19ms/step - g_loss: 4.9788 - d_loss: 0.1661
Epoch 149/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.0016 - d_loss: 0.1601
Epoch 150/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.0481 - d_loss: 0.1615
Epoch 151/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.2094 - d_loss: 0.1508
Epoch 152/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.3830 - d_loss: 0.1512
Epoch 153/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.3900 - d_loss: 0.1485
Epoch 154/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.4395 - d_loss: 0.1422
Epoch 155/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.5505 - d_loss: 0.1387
Epoch 156/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.5679 - d_loss: 0.1368
Epoch 157/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.6652 - d_loss: 0.1433
Epoch 158/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.6960 - d_loss: 0.1350
Epoch 159/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.8526 - d_loss: 0.1314
Epoch 160/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.8966 - d_loss: 0.1313
Epoch 161/200
938/938 [==============================] - 18s 19ms/step - g_loss: 5.9361 - d_loss: 0.1266
Epoch 162/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.0354 - d_loss: 0.1271
Epoch 163/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.1452 - d_loss: 0.1211
Epoch 164/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.1628 - d_loss: 0.1191
Epoch 165/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.2219 - d_loss: 0.1188
Epoch 166/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.2460 - d_loss: 0.1217
Epoch 167/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.3292 - d_loss: 0.1151
Epoch 168/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.3941 - d_loss: 0.1190
Epoch 169/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.4279 - d_loss: 0.1124
Epoch 170/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.5312 - d_loss: 0.1097
Epoch 171/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.5116 - d_loss: 0.1127
Epoch 172/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.5757 - d_loss: 0.1088
Epoch 173/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.5861 - d_loss: 0.1072
Epoch 174/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.7037 - d_loss: 0.1138
Epoch 175/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.7185 - d_loss: 0.1051
Epoch 176/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.8267 - d_loss: 0.1023
Epoch 177/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.7998 - d_loss: 0.1039
Epoch 178/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.8794 - d_loss: 0.1000
Epoch 179/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.9170 - d_loss: 0.0976
Epoch 180/200
938/938 [==============================] - 18s 19ms/step - g_loss: 6.9915 - d_loss: 0.0970
Epoch 181/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.0351 - d_loss: 0.1008
Epoch 182/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.0118 - d_loss: 0.0955
Epoch 183/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.0932 - d_loss: 0.0947
Epoch 184/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.2023 - d_loss: 0.0934
Epoch 185/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.0835 - d_loss: 0.0920
Epoch 186/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.2153 - d_loss: 0.0918
Epoch 187/200
938/938 [==============================] - 17s 19ms/step - g_loss: 7.1330 - d_loss: 0.0903
Epoch 188/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.2981 - d_loss: 0.0879
Epoch 189/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.3779 - d_loss: 0.0897
Epoch 190/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.4587 - d_loss: 0.0891
Epoch 191/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.4046 - d_loss: 0.0903
Epoch 192/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.3697 - d_loss: 0.0916
Epoch 193/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.4905 - d_loss: 0.0851
Epoch 194/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.5004 - d_loss: 0.0873
Epoch 195/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.5865 - d_loss: 0.0867
Epoch 196/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.5976 - d_loss: 0.0823
Epoch 197/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.6270 - d_loss: 0.0807
Epoch 198/200
938/938 [==============================] - 17s 19ms/step - g_loss: 7.6080 - d_loss: 0.0816
Epoch 199/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.6571 - d_loss: 0.0805
Epoch 200/200
938/938 [==============================] - 18s 19ms/step - g_loss: 7.8012 - d_loss: 0.0758
Out[62]:
<keras.callbacks.History at 0x7f9ae039d1c0>
In [71]:
#save model
cond_gan.generator.save('models/cgan.h5')
gen=cond_gan.generator
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.

Plot images

In [22]:
class_names = ["airplane", "automobile", "bird", "cat", "deer",               
               "dog", "frog", "horse", "ship", "truck"]
# Generate 100 images in total
num_images = 100
num_classes = 10

plt.figure(figsize=(15,15))
# Generate 10 random latent vectors for each class
random_latent_vectors = tf.random.normal(shape=(num_images, 128))

# Generate one-hot labels for each class
one_hot_labels = np.zeros((num_images, num_classes))
for i in range(num_classes):
    start_idx = i * 10
    end_idx = start_idx + 10
    one_hot_labels[start_idx:end_idx, i] = 1

# Concatenate the random latent vectors and one-hot labels
random_vector_labels = tf.concat([random_latent_vectors, one_hot_labels], axis=1)

# Generate images using the CGAN model
imgs = gen.predict(random_vector_labels)

# Scale the images from [-1, 1] to [0, 1]
imgs = (imgs + 1) / 2.0

# Plot the generated images
rowCount=0
for i in range(num_images):
    # Define the subplot
    plt.subplot(num_images // 10, 10, 1 + i)
    # Turn off axis
    # Plot raw pixel data
    if i%10 == 0:
            plt.ylabel(class_names[rowCount], size=10)
            rowCount+=1
    else:
        plt.axis('off')
    plt.imshow(imgs[i])
plt.savefig('CGAN.png')
plt.show()
4/4 [==============================] - 0s 2ms/step

Calculate FID

In [17]:
from torchmetrics.image.fid import FrechetInceptionDistance
import torch
fid=FrechetInceptionDistance(reset_real_features=False).to('cuda')
fid.reset()
real_images = loadRealSamples()
real_images = real_images.astype(np.float32)
real_images=torch.from_numpy(np.array(real_images))

#change dtypes, scale back first
real_images=(real_images+1)*(255/2)
real_images=real_images.type(torch.uint8)
#from 50,000x32x32x3 to 50,000x3x32x32
real_images = real_images.permute(0, 3, 1, 2)
real_images = real_images.to('cuda')
#COMPUTE FID

# create a dataloader with batch size 128
dataloader = torch.utils.data.DataLoader(real_images, batch_size=128, shuffle=True,
                                          drop_last=True
                                          )


# update the FID with real images
for batch,images in enumerate(dataloader):
    print(f'Updating with real images now...Batch{batch}',end='\r')
    fid.update(images, real=True)

print(f'\nFinished! Ready to take in generated image features')
Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 7s 0us/step
Updating with real images now...Batch467
Finished! Ready to take in generated image features
In [ ]:
def computeFID(generator,latentDim):
    fid.reset()
    #PREPARE DATA
    random_latent_vectors = tf.random.normal(shape=(60000, 128))
    one_hot_labels=tf.one_hot(tf.cast(tf.range(0,60000), dtype=tf.int32), depth=10)
    random_vector_labels = tf.concat(
        [random_latent_vectors, one_hot_labels], axis=1
    )
    gen_images=generator.predict(random_vector_labels)
#     gen_images=np.reshape(gen_images,(50000,3,32,32))
    gen_images = gen_images.astype(np.float32)
    gen_images = torch.from_numpy(np.array(gen_images))
    gen_images=(gen_images+1)*(255/2)
    gen_images=gen_images.type(torch.uint8)
    gen_images = gen_images.permute(0, 3, 1, 2)
    gen_images = gen_images.to('cuda')

    # create a dataloader with batch size 128
    dataloader_gen = torch.utils.data.DataLoader(gen_images, batch_size=128, shuffle=True,
                                            # pin_memory=True,
                                             drop_last=True)
    
    print('')
    # update the FID with generated images
    for batch,images in enumerate(dataloader_gen):
        print(f'Updating FID with generated images: Batch{batch}',end='\r')
        fid.update(images, real=False)
    
    print('\n')
    score = fid.compute()
    print(f'FID: {score.item()}')
    return score
In [26]:
computeFID(gen,128)
print('')
Updating FID with generated images: Batch467

FID: 41.770751953125

Model 5: WGAN

WGAN (Wasserstein Generative Adversarial Network) uses the Wasserstein distance metric to evaluate the performance of the generator. Using WGAN over regular DCGAN (Deep Convolutional GAN) can give us better potential in our model as it provides a more stable training process and produces higher-quality generated images. To implement this architecture, we must overhaul our entire training loop/architecture

In [7]:
IMG_SHAPE = (32,32,3)
BATCH_SIZE = 512

# Size of the noise vector
noise_dim = 128

# Reshape each sample to (28, 28, 1) and normalize the pixel values in the [-1, 1] range
train_images = loadRealSamples()
In [8]:
def conv_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="same",
    use_bias=True,
    use_bn=False,
    use_dropout=False,
    drop_value=0.5,
):
    x = layers.Conv2D(
        filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
    )(x)
    if use_bn:
        x = layers.BatchNormalization()(x)
    x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x


def get_discriminator_model():
    img_input = layers.Input(shape=IMG_SHAPE)
    # Zero pad the input to make the input images size to (32, 32, 1).
    x = layers.ZeroPadding2D((2, 2))(img_input)
    x = conv_block(
        x,
        64,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        use_bias=True,
        activation=layers.LeakyReLU(0.2),
        use_dropout=False,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        128,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=True,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        256,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=True,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        512,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=False,
        drop_value=0.3,
    )

    x = layers.Flatten()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(1)(x)

    d_model = keras.models.Model(img_input, x, name="discriminator")
    return d_model


d_model = get_discriminator_model()
d_model.summary()
Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 zero_padding2d (ZeroPadding  (None, 36, 36, 3)        0         
 2D)                                                             
                                                                 
 conv2d (Conv2D)             (None, 18, 18, 64)        4864      
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 18, 18, 64)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 9, 9, 128)         204928    
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 9, 9, 128)         0         
                                                                 
 dropout (Dropout)           (None, 9, 9, 128)         0         
                                                                 
 conv2d_2 (Conv2D)           (None, 5, 5, 256)         819456    
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 5, 5, 256)         0         
                                                                 
 dropout_1 (Dropout)         (None, 5, 5, 256)         0         
                                                                 
 conv2d_3 (Conv2D)           (None, 3, 3, 512)         3277312   
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 3, 3, 512)         0         
                                                                 
 flatten (Flatten)           (None, 4608)              0         
                                                                 
 dropout_2 (Dropout)         (None, 4608)              0         
                                                                 
 dense (Dense)               (None, 1)                 4609      
                                                                 
=================================================================
Total params: 4,311,169
Trainable params: 4,311,169
Non-trainable params: 0
_________________________________________________________________
2023-02-02 08:14:05.549478: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-02 08:14:05.551202: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-02 08:14:05.552687: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-02 08:14:06.488119: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-02 08:14:06.488970: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-02 08:14:06.489673: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-02 08:14:06.490349: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 46568 MB memory:  -> device: 0, name: NVIDIA A40, pci bus id: 0000:85:00.0, compute capability: 8.6
In [9]:
def upsample_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    up_size=(2, 2),
    padding="same",
    use_bn=False,
    use_bias=True,
    use_dropout=False,
    drop_value=0.3,
):
    x = layers.UpSampling2D(up_size)(x)
    x = layers.Conv2D(
        filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
    )(x)

    if use_bn:
        x = layers.BatchNormalization()(x)

    if activation:
        x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x


def get_generator_model():
    noise = layers.Input(shape=(noise_dim,))
    x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Reshape((4, 4, 256))(x)
    x = upsample_block(
        x,
        128,
        layers.LeakyReLU(0.2),
        strides=(1, 1),
        use_bias=False,
        use_bn=True,
        padding="same",
        use_dropout=False,
    )
    x = upsample_block(
        x,
        64,
        layers.LeakyReLU(0.2),
        strides=(1, 1),
        use_bias=False,
        use_bn=True,
        padding="same",
        use_dropout=False,
    )
    x = upsample_block(
        x, 3, layers.Activation("tanh"), strides=(1, 1), use_bias=False, use_bn=True
    )
#     # At this point, we have an output which has the same shape as the input, (32, 32, 1).
#     # We will use a Cropping2D layer to make it (28, 28, 1).
#     x = layers.Cropping2D((2, 2))(x)

    g_model = keras.models.Model(noise, x, name="generator")
    return g_model


g_model = get_generator_model()
g_model.summary()
Model: "generator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 128)]             0         
                                                                 
 dense_1 (Dense)             (None, 4096)              524288    
                                                                 
 batch_normalization (BatchN  (None, 4096)             16384     
 ormalization)                                                   
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 4096)              0         
                                                                 
 reshape (Reshape)           (None, 4, 4, 256)         0         
                                                                 
 up_sampling2d (UpSampling2D  (None, 8, 8, 256)        0         
 )                                                               
                                                                 
 conv2d_4 (Conv2D)           (None, 8, 8, 128)         294912    
                                                                 
 batch_normalization_1 (Batc  (None, 8, 8, 128)        512       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 8, 8, 128)         0         
                                                                 
 up_sampling2d_1 (UpSampling  (None, 16, 16, 128)      0         
 2D)                                                             
                                                                 
 conv2d_5 (Conv2D)           (None, 16, 16, 64)        73728     
                                                                 
 batch_normalization_2 (Batc  (None, 16, 16, 64)       256       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_6 (LeakyReLU)   (None, 16, 16, 64)        0         
                                                                 
 up_sampling2d_2 (UpSampling  (None, 32, 32, 64)       0         
 2D)                                                             
                                                                 
 conv2d_6 (Conv2D)           (None, 32, 32, 3)         1728      
                                                                 
 batch_normalization_3 (Batc  (None, 32, 32, 3)        12        
 hNormalization)                                                 
                                                                 
 activation (Activation)     (None, 32, 32, 3)         0         
                                                                 
=================================================================
Total params: 911,820
Trainable params: 903,238
Non-trainable params: 8,582
_________________________________________________________________
In [ ]:
class WGAN(tf.keras.Model):
    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # Get the interpolated image
        alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        # 1. Train the generator and get the generator loss
        # 2. Train the discriminator and get the discriminator loss
        # 3. Calculate the gradient penalty
        # 4. Multiply this gradient penalty with a constant weight factor
        # 5. Add the gradient penalty to the discriminator loss
        # 6. Return the generator and discriminator losses as a loss dictionary

        # Train the discriminator first. The original paper recommends training
        # the discriminator for `x` more steps (typically 5) as compared to
        # one step of the generator. Here we will train it for 3 extra steps
        # as compared to 5 to reduce the training time.
        for i in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for the real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}
In [ ]:
class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=6, latent_dim=128):
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        generated_images = (generated_images * 127.5) + 127.5

        for i in range(self.num_img):
            img = generated_images[i].numpy()
            img = keras.preprocessing.image.array_to_img(img)
            img.save("generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch))
In [27]:
# Instantiate the optimizer for both networks
# (learning_rate=0.0002, beta_1=0.5 are recommended)
generator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)

# Define the loss functions for the discriminator,
# which should be (fake_loss - real_loss).
# We will add the gradient penalty later to this loss function.
def discriminator_loss(real_img, fake_img):
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss


# Define the loss functions for the generator.
def generator_loss(fake_img):
    return -tf.reduce_mean(fake_img)



# Instantiate the customer `GANMonitor` Keras callback.
cbk = GANMonitor(num_img=3, latent_dim=noise_dim)

# Get the wgan model
wgan = WGAN(
    discriminator=d_model,
    generator=g_model,
    latent_dim=noise_dim,
    discriminator_extra_steps=3,
)

# Compile the wgan model
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
    d_loss_fn=discriminator_loss,
)

# Start training
wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=200, callbacks=[cbk])
Epoch 1/200
2023-02-02 09:39:05.069476: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape indiscriminator/dropout/dropout_1/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
118/118 [==============================] - 32s 182ms/step - d_loss: -1.0401 - g_loss: -0.9515
Epoch 2/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0964 - g_loss: -0.3057
Epoch 3/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.1542 - g_loss: 0.2998
Epoch 4/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.1486 - g_loss: 1.2698
Epoch 5/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.1390 - g_loss: -1.0493
Epoch 6/200
118/118 [==============================] - 22s 184ms/step - d_loss: -1.0449 - g_loss: 0.8757
Epoch 7/200
118/118 [==============================] - 22s 185ms/step - d_loss: -1.1025 - g_loss: -0.4341
Epoch 8/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.1386 - g_loss: -0.9943
Epoch 9/200
118/118 [==============================] - 22s 184ms/step - d_loss: -1.0583 - g_loss: -0.4558
Epoch 10/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.1025 - g_loss: -0.3207
Epoch 11/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0393 - g_loss: 1.0345
Epoch 12/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.2176 - g_loss: 1.9755
Epoch 13/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.9568 - g_loss: -0.3262
Epoch 14/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0361 - g_loss: -1.3493
Epoch 15/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.0667 - g_loss: 0.3637
Epoch 16/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.1126 - g_loss: 0.1226
Epoch 17/200
118/118 [==============================] - 22s 184ms/step - d_loss: -1.0540 - g_loss: 0.9403
Epoch 18/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.1360 - g_loss: 0.3331
Epoch 19/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9929 - g_loss: 0.4072
Epoch 20/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0216 - g_loss: 0.0267
Epoch 21/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.0504 - g_loss: 0.2469
Epoch 22/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.0552 - g_loss: 1.1193
Epoch 23/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0645 - g_loss: -0.6569
Epoch 24/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.1058 - g_loss: 0.7232
Epoch 25/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0048 - g_loss: 0.2357
Epoch 26/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0513 - g_loss: 1.5711
Epoch 27/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0856 - g_loss: -1.6032
Epoch 28/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.9976 - g_loss: -2.2542
Epoch 29/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.1039 - g_loss: 1.8658
Epoch 30/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.1149 - g_loss: 0.7666
Epoch 31/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.0424 - g_loss: -1.0199
Epoch 32/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.0665 - g_loss: 1.7943
Epoch 33/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.1162 - g_loss: -1.8583
Epoch 34/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0389 - g_loss: -1.9026
Epoch 35/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0090 - g_loss: -1.0905
Epoch 36/200
118/118 [==============================] - 22s 184ms/step - d_loss: -1.0586 - g_loss: -1.2688
Epoch 37/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0698 - g_loss: 0.6948
Epoch 38/200
118/118 [==============================] - 22s 184ms/step - d_loss: -1.0507 - g_loss: -1.3311
Epoch 39/200
118/118 [==============================] - 22s 184ms/step - d_loss: -1.0529 - g_loss: 0.4101
Epoch 40/200
118/118 [==============================] - 22s 184ms/step - d_loss: -1.0604 - g_loss: -0.8697
Epoch 41/200
118/118 [==============================] - 22s 184ms/step - d_loss: -1.0125 - g_loss: -0.7380
Epoch 42/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9713 - g_loss: 0.0688
Epoch 43/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.9505 - g_loss: -2.1671
Epoch 44/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0726 - g_loss: -1.8150
Epoch 45/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0451 - g_loss: -0.8659
Epoch 46/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0050 - g_loss: 0.0902
Epoch 47/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0698 - g_loss: 0.5287
Epoch 48/200
118/118 [==============================] - 22s 184ms/step - d_loss: -1.0910 - g_loss: -0.2050
Epoch 49/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0510 - g_loss: 2.9583
Epoch 50/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0028 - g_loss: -0.3945
Epoch 51/200
118/118 [==============================] - 21s 182ms/step - d_loss: -1.0245 - g_loss: -1.6015
Epoch 52/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0687 - g_loss: -0.9000
Epoch 53/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9544 - g_loss: -4.1124
Epoch 54/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9798 - g_loss: -1.4708
Epoch 55/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.0737 - g_loss: -1.7330
Epoch 56/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9426 - g_loss: -2.0518
Epoch 57/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9734 - g_loss: 1.7406
Epoch 58/200
118/118 [==============================] - 22s 184ms/step - d_loss: -1.0128 - g_loss: -1.1506
Epoch 59/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0068 - g_loss: -1.1105
Epoch 60/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9884 - g_loss: -0.0310
Epoch 61/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.0069 - g_loss: -0.0804
Epoch 62/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.9697 - g_loss: -1.1458
Epoch 63/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9419 - g_loss: -0.9006
Epoch 64/200
118/118 [==============================] - 21s 182ms/step - d_loss: -1.0417 - g_loss: -1.7282
Epoch 65/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.9317 - g_loss: -0.1046
Epoch 66/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.9435 - g_loss: -0.2653
Epoch 67/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8881 - g_loss: 0.6992
Epoch 68/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9699 - g_loss: -0.2974
Epoch 69/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9467 - g_loss: -2.5310
Epoch 70/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9223 - g_loss: -1.6679
Epoch 71/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9712 - g_loss: -0.0133
Epoch 72/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9890 - g_loss: -0.0631
Epoch 73/200
118/118 [==============================] - 22s 184ms/step - d_loss: -1.0273 - g_loss: 1.5172
Epoch 74/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9407 - g_loss: -0.8719
Epoch 75/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9416 - g_loss: -1.5982
Epoch 76/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0549 - g_loss: -1.3059
Epoch 77/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9962 - g_loss: -0.6555
Epoch 78/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9592 - g_loss: -0.4954
Epoch 79/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8692 - g_loss: -0.0554
Epoch 80/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.9923 - g_loss: 1.0494
Epoch 81/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9534 - g_loss: 0.8318
Epoch 82/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0674 - g_loss: -1.2399
Epoch 83/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9344 - g_loss: -0.9420
Epoch 84/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.9212 - g_loss: -1.7512
Epoch 85/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9752 - g_loss: -1.0821
Epoch 86/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9065 - g_loss: -0.6431
Epoch 87/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9598 - g_loss: -1.8881
Epoch 88/200
118/118 [==============================] - 21s 182ms/step - d_loss: -0.9809 - g_loss: -0.4641
Epoch 89/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8728 - g_loss: -0.2860
Epoch 90/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.9279 - g_loss: 0.1486
Epoch 91/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.0488 - g_loss: -0.0494
Epoch 92/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9919 - g_loss: -2.5631
Epoch 93/200
118/118 [==============================] - 22s 182ms/step - d_loss: -1.0214 - g_loss: -0.0847
Epoch 94/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9715 - g_loss: -0.6475
Epoch 95/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9023 - g_loss: 1.6780
Epoch 96/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8645 - g_loss: -0.9717
Epoch 97/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9092 - g_loss: -0.0488
Epoch 98/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9666 - g_loss: -2.2301
Epoch 99/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0063 - g_loss: 0.6875
Epoch 100/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9618 - g_loss: -0.0585
Epoch 101/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9258 - g_loss: 1.6849
Epoch 102/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9336 - g_loss: -0.7886
Epoch 103/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8862 - g_loss: -0.3469
Epoch 104/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.9664 - g_loss: -0.9488
Epoch 105/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8888 - g_loss: 0.5290
Epoch 106/200
118/118 [==============================] - 22s 183ms/step - d_loss: -1.0122 - g_loss: -1.1686
Epoch 107/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.9042 - g_loss: -0.9016
Epoch 108/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9339 - g_loss: 1.4489
Epoch 109/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9232 - g_loss: -0.2516
Epoch 110/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9492 - g_loss: -0.5228
Epoch 111/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8364 - g_loss: -2.3753
Epoch 112/200
118/118 [==============================] - 21s 182ms/step - d_loss: -0.9107 - g_loss: -1.5424
Epoch 113/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.9267 - g_loss: -0.2958
Epoch 114/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8950 - g_loss: 0.3609
Epoch 115/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8422 - g_loss: -1.6643
Epoch 116/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8297 - g_loss: -0.1097
Epoch 117/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9561 - g_loss: -0.1289
Epoch 118/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9153 - g_loss: 0.2936
Epoch 119/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8651 - g_loss: 2.5090
Epoch 120/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8956 - g_loss: -0.5883
Epoch 121/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8685 - g_loss: -1.8754
Epoch 122/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9004 - g_loss: -2.3465
Epoch 123/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8730 - g_loss: 1.0858
Epoch 124/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9201 - g_loss: -0.2223
Epoch 125/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8455 - g_loss: -0.9490
Epoch 126/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8523 - g_loss: -3.4812
Epoch 127/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8456 - g_loss: -1.9522
Epoch 128/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8990 - g_loss: -2.8476
Epoch 129/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8560 - g_loss: -4.0780
Epoch 130/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9107 - g_loss: -0.7019
Epoch 131/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8772 - g_loss: 0.2276
Epoch 132/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8622 - g_loss: -3.1653
Epoch 133/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8685 - g_loss: -1.3595
Epoch 134/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8494 - g_loss: -1.9143
Epoch 135/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8293 - g_loss: -0.7158
Epoch 136/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8335 - g_loss: 0.6058
Epoch 137/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9022 - g_loss: 0.6262
Epoch 138/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8683 - g_loss: -2.6317
Epoch 139/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8090 - g_loss: -0.3048
Epoch 140/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8666 - g_loss: -0.2409
Epoch 141/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8541 - g_loss: -1.3026
Epoch 142/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8441 - g_loss: -0.5505
Epoch 143/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8790 - g_loss: -0.8716
Epoch 144/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8655 - g_loss: -2.3255
Epoch 145/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9079 - g_loss: -1.6966
Epoch 146/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8026 - g_loss: -0.7970
Epoch 147/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8835 - g_loss: -3.0788
Epoch 148/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8297 - g_loss: -2.7591
Epoch 149/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8210 - g_loss: -0.5819
Epoch 150/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8080 - g_loss: -2.0905
Epoch 151/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.9032 - g_loss: -3.1223
Epoch 152/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.9257 - g_loss: -1.2520
Epoch 153/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8430 - g_loss: -0.2367
Epoch 154/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8502 - g_loss: -2.8022
Epoch 155/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8287 - g_loss: -3.2176
Epoch 156/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8217 - g_loss: -3.1021
Epoch 157/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8231 - g_loss: -0.8938
Epoch 158/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8912 - g_loss: -1.5253
Epoch 159/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8573 - g_loss: -0.2047
Epoch 160/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.7986 - g_loss: -1.7747
Epoch 161/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8765 - g_loss: -1.4138
Epoch 162/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8370 - g_loss: -1.5492
Epoch 163/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8575 - g_loss: -1.1505
Epoch 164/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.7432 - g_loss: -0.5456
Epoch 165/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8571 - g_loss: -2.6627
Epoch 166/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8342 - g_loss: -0.0472
Epoch 167/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8618 - g_loss: -0.7509
Epoch 168/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.7820 - g_loss: -1.1393
Epoch 169/200
118/118 [==============================] - 21s 182ms/step - d_loss: -0.8745 - g_loss: -0.8424
Epoch 170/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.7781 - g_loss: -0.6736
Epoch 171/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8093 - g_loss: -3.3313
Epoch 172/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8629 - g_loss: -3.6358
Epoch 173/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8341 - g_loss: -3.6231
Epoch 174/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8098 - g_loss: -1.4983
Epoch 175/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8205 - g_loss: -1.3333
Epoch 176/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.7815 - g_loss: -1.0834
Epoch 177/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8457 - g_loss: -1.7522
Epoch 178/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.7860 - g_loss: -0.6035
Epoch 179/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.7935 - g_loss: -0.9891
Epoch 180/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8327 - g_loss: -2.5563
Epoch 181/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8270 - g_loss: -2.4587
Epoch 182/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8587 - g_loss: -2.0459
Epoch 183/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8482 - g_loss: -2.1143
Epoch 184/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.7832 - g_loss: -3.2805
Epoch 185/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8298 - g_loss: -2.4422
Epoch 186/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8083 - g_loss: -2.6274
Epoch 187/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.8781 - g_loss: -1.5630
Epoch 188/200
118/118 [==============================] - 22s 184ms/step - d_loss: -0.7979 - g_loss: -0.8324
Epoch 189/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8426 - g_loss: -1.9164
Epoch 190/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8022 - g_loss: -0.4767
Epoch 191/200
118/118 [==============================] - 22s 182ms/step - d_loss: -0.8049 - g_loss: -1.6669
Epoch 192/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8014 - g_loss: -2.7349
Epoch 193/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8417 - g_loss: -1.9107
Epoch 194/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8538 - g_loss: -1.7004
Epoch 195/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8374 - g_loss: -2.4114
Epoch 196/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8284 - g_loss: -1.4625
Epoch 197/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8386 - g_loss: -1.1710
Epoch 198/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8059 - g_loss: -1.6091
Epoch 199/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8036 - g_loss: -0.7230
Epoch 200/200
118/118 [==============================] - 22s 183ms/step - d_loss: -0.8198 - g_loss: -2.4709
Out[27]:
<keras.callbacks.History at 0x7f0e18707970>
In [28]:
#save model
gen=wgan.generator
gen.save('WGAN.h5')
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.

Plot images

In [27]:
gen=tf.keras.models.load_model('WGAN.h5',custom_objects={'gaussian_weight_init':gaussian_weight_init})
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
In [57]:
random_latent_vectors = tf.random.normal(shape=(25, 100))
imgs=gen.predict(random_latent_vectors)
# scale from [-1,1] to [0,1]
imgs = (imgs + 1) / 2.0
for i in range(5 * 5):
        # define subplot
        plt.subplot(5, 5, 1 + i)
        # turn off axis
        plt.axis('off')
        # plot raw pixel data
        plt.imshow(imgs[i])
plt.savefig('WGAN.png')
plt.show()
1/1 [==============================] - 0s 13ms/step

Calculate FID

Previously in training we were calculating FID on the 10K test set of images. This is quite efficient as it is fast to compute but it is not the most accurate way to calculate FID. We shall now calculate on the full set of 50K+10K=60K images since it is what we trained on too.

In [30]:
from torchmetrics.image.fid import FrechetInceptionDistance
import torch
fid=FrechetInceptionDistance(reset_real_features=False).to('cuda')
fid.reset()
real_images = loadRealSamples()
real_images = real_images.astype(np.float32)
real_images=torch.from_numpy(np.array(real_images))

#change dtypes, scale back first
real_images=(real_images+1)*(255/2)
real_images=real_images.type(torch.uint8)
#from 50,000x32x32x3 to 50,000x3x32x32
real_images = real_images.permute(0, 3, 1, 2)
real_images = real_images.to('cuda')
#COMPUTE FID

# create a dataloader with batch size 128
dataloader = torch.utils.data.DataLoader(real_images, batch_size=128, shuffle=True,
                                          drop_last=True
                                          )


# update the FID with real images
for batch,images in enumerate(dataloader):
    print(f'Updating with real images now...Batch{batch}',end='\r')
    fid.update(images, real=True)

print(f'\nFinished! Ready to take in generated image features')
Updating with real images now...Batch467
Finished! Ready to take in generated image features
In [31]:
def computeFID(generator,latentDim):
    fid.reset()
    #PREPARE DATA
    random_latent_vectors = tf.random.normal(shape=(60000, 128))
    gen_images=generator.predict(random_latent_vectors)
#     gen_images=np.reshape(gen_images,(50000,3,32,32))
    gen_images = gen_images.astype(np.float32)
    gen_images = torch.from_numpy(np.array(gen_images))
    gen_images=(gen_images+1)*(255/2)
    gen_images=gen_images.type(torch.uint8)
    gen_images = gen_images.permute(0, 3, 1, 2)
    gen_images = gen_images.to('cuda')

    # create a dataloader with batch size 128
    dataloader_gen = torch.utils.data.DataLoader(gen_images, batch_size=128, shuffle=True,
                                            # pin_memory=True,
                                             drop_last=True)
    
    print('')
    # update the FID with generated images
    for batch,images in enumerate(dataloader_gen):
        print(f'Updating FID with generated images: Batch{batch}',end='\r')
        fid.update(images, real=False)
    
    print('\n')
    score = fid.compute()
    print(f'FID: {score.item()}')
    return score
In [33]:
computeFID(gen,128)
print('')
Updating FID with generated images: Batch467

FID: 35.99958419799805

Summary WGAN

This is our 2nd best model, with more tuning this can definitely be improved a lot but for the sake of time we will stop here. I already had to tune quite a lot to get the model to converge at all so I am quite happy with this result

Hyperparameter Tuning Model 3

Our best model so far in terms of FID was our DCGAN with 32 FID. Let's tune the following parameters to try to improve it

  • Discriminator Learning Rate (dLR)
  • Generator Learning Rate (gLR)
  • Kernel Size of Generator Conv layers (k)
    For the sake of time we will run 2 stages of tuning, one for kernel size, and another one for the 2 learning rates. We do the 2 learning rates together because they have a close relationship so its better to test in combinations
In [14]:
def createImageGen():
  datagen = tf.keras.preprocessing.image.ImageDataGenerator(
      rotation_range=15,
      width_shift_range=0.15,
      height_shift_range=0.15,
      zoom_range=0.2,
      horizontal_flip=True
      )
  return datagen

#function to reverse scaling/preprocessing steps, for data visualisation purposes
def reverseProcess(X):
    X = (X *127.5)+127.5
    X = X.astype('uint8')
    return X
    
In [15]:
datagen=createImageGen()

FID

In [16]:
from torchmetrics.image.fid import FrechetInceptionDistance
import torch
fid=FrechetInceptionDistance(reset_real_features=False).to('cuda')
fid.reset()
real_images = loadTestSamples()
real_images = real_images.astype(np.float32)
real_images=torch.from_numpy(np.array(real_images))

#change dtypes, scale back first
real_images=(real_images+1)*(255/2)
real_images=real_images.type(torch.uint8)
#from 50,000x32x32x3 to 50,000x3x32x32
real_images = real_images.permute(0, 3, 1, 2)
real_images = real_images.to('cuda')
#COMPUTE FID

# create a dataloader with batch size 128
dataloader = torch.utils.data.DataLoader(real_images, batch_size=128, shuffle=True,
                                          drop_last=True
                                          )

# update the FID with real images
for batch,images in enumerate(dataloader):
    print(f'Updating with real images now...Batch{batch}',end='\r')
    fid.update(images, real=True)

print(f'\nFinished! Ready to take in generated image features')

def computeFID(generator,latentDim):
    fid.reset()
    #PREPARE DATA
    gen_images = generateFakeSamples(generator,latentDim,10000)[0]
#     gen_images=np.reshape(gen_images,(50000,3,32,32))
    gen_images = gen_images.astype(np.float32)
    gen_images = torch.from_numpy(np.array(gen_images))
    gen_images=(gen_images+1)*(255/2)
    gen_images=gen_images.type(torch.uint8)
    gen_images = gen_images.permute(0, 3, 1, 2)
    gen_images = gen_images.to('cuda')

    # create a dataloader with batch size 128
    dataloader_gen = torch.utils.data.DataLoader(gen_images, batch_size=128, shuffle=True,
                                            # pin_memory=True,
                                             drop_last=True)
    
    print('')
    # update the FID with generated images
    for batch,images in enumerate(dataloader_gen):
        print(f'Updating FID with generated images: Batch{batch}',end='\r')
        fid.update(images, real=False)
    
    print('\n')
    score = fid.compute()
    print(f'FID: {score.item()}')
    return score
Updating with real images now...Batch77
Finished! Ready to take in generated image features

Model architecture left unchanged from model 3

  • Allow change of kernel size in generator
In [17]:
# define the standalone discriminator model
def defineDisc(in_shape=(32, 32, 3)):
    model = tf.keras.models.Sequential()
    # normal
    model.add(tf.keras.layers.Conv2D(128, (3, 3), padding='same', input_shape=in_shape,kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # downsample
    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # classifier
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dropout(0.4))
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
    # compile model
    opt = tf.keras.optimizers.Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

# define the standalone generator model
def defineGen(latent_dim,k):
    model = tf.keras.models.Sequential()
    # foundation for 4x4 image
    n_nodes = 256 * 4 * 4
    model.add(tf.keras.layers.Dense(n_nodes, input_dim=latent_dim))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    model.add(tf.keras.layers.Reshape((4, 4, 256)))
    # upsample to 8x8
    model.add(tf.keras.layers.Conv2DTranspose(128, kernel_size=k, strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # upsample to 16x16
    model.add(tf.keras.layers.Conv2DTranspose(128,  kernel_size=k, strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # upsample to 32x32
    model.add(tf.keras.layers.Conv2DTranspose(64,  kernel_size=k, strides=(2, 2), padding='same',kernel_initializer=gaussian_weight_init))
    model.add(tf.keras.layers.BatchNormalization(momentum=0.9))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # output layer
    model.add(tf.keras.layers.Conv2D(3, (3, 3), activation='tanh', padding='same'))
    return model

Learning rate scheduler

Now we have a dynamic scheduler for both discriminator and generator which simply takes one initial learning rate and scales them down depending on epoch count

In [18]:
#Learning rate schedulers
def scheduler(epoch,LR):
    if epoch < 30:
        return LR
    elif epoch < 60:
        return LR*0.8
    elif epoch < 120:
        return LR*0.65
    else:
        return LR*0.55
In [19]:
# define the combined generator and discriminator model, for updating the generator
def defineGAN(genModel, discModel):
    # make weights in the discriminator not trainable
    discModel.trainable = False
    # connect them
    model = tf.keras.models.Sequential()
    # add generator
    model.add(genModel)
    # add the discriminator
    model.add(discModel)
    # compile model
    opt = tf.keras.optimizers.Adam(lr=0.0002,beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

# train the generator and discriminator
def trainGAN(genModel, discModel, gan_model, dataset,latentDim, n_epochs=200, n_batch=128,fid_frequency=50,sumFrequency=2,name='model',best_FID=500,dLR=0.00015,gLR=0.0002):
    bat_per_epo = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    bestFID=best_FID #previous best FID
    iter=datagen.flow(dataset,y=None,batch_size=half_batch) #half batch size
    # manually enumerate epochs
    for i in range(n_epochs):
        #set learning rates via schedulers
        K.set_value(discModel.optimizer.learning_rate,scheduler(i,dLR))
        K.set_value(gan_model.optimizer.learning_rate,scheduler(i,gLR))
        
        
        
        # enumerate batches over the training set
        for j in range(bat_per_epo):
            # get randomly selected augmented samples
            X_real=iter.next()
            y_real = np.ones((len(X_real), 1)) #get 'real' labels
            #smoothing labels
            y_real=smooth_positive_labels(y_real)
            # update discriminator model weights
            d_loss1, _ = discModel.train_on_batch(X_real, y_real)
            # generate 'fake' examples
            X_fake, y_fake = generateFakeSamples(genModel, latentDim, half_batch)
            # update discriminator model weights
            d_loss2, _ = discModel.train_on_batch(X_fake, y_fake)
            # prepare points in latent space as input for the generator
            X_gan = generateLatentPoints(latentDim, n_batch)
            # create inverted labels for the fake samples
            y_gan = np.ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            
        
            #progress bar per epoch
            if j%25==0 or j==bat_per_epo-1:
                dLR_sf = round(discModel.optimizer.learning_rate.numpy() / 10 ** math.floor(math.log10(discModel.optimizer.learning_rate.numpy())),1) * 10 ** math.floor(math.log10(discModel.optimizer.learning_rate.numpy()))
                gLR_sf = round(gan_model.optimizer.learning_rate.numpy() / 10 ** math.floor(math.log10(gan_model.optimizer.learning_rate.numpy())),1) * 10 ** math.floor(math.log10(gan_model.optimizer.learning_rate.numpy()))
                if j==bat_per_epo-1: #if last step
                    print(f'>epoch{i+1}, {j+1}/{bat_per_epo}, dLR:{dLR_sf}, gLR:{gLR_sf}, d1={d_loss1:.3f}, d2={d_loss2:.3f}, g={g_loss:.3f}')
                else:
                    print(f'>epoch{i+1}, {j+1}/{bat_per_epo}, dLR:{dLR_sf}, gLR:{gLR_sf}, d1={d_loss1:.3f}, d2={d_loss2:.3f}, g={g_loss:.3f}', 
                          end='\r')
        #compute FID for every interval and last epoch
        if (i>0 and (i+1)%fid_frequency==0) or (i+1)==n_epochs:
            updated_gen_model = tf.keras.models.Sequential()
            updated_gen_model.add(gan_model.layers[0])
            # Set the weights of the new model to be the same as the generator portion of the gan_model
            updated_gen_model.set_weights(gan_model.layers[0].get_weights())
            FID=computeFID(updated_gen_model,latentDim)
            if FID<bestFID:
                print('saving best model')
                bestFID=FID
                #save model
                updated_gen_model = tf.keras.models.Sequential()
                updated_gen_model.add(gan_model.layers[0])
                # Set the weights of the new model to be the same as the generator portion of the gan_model
                updated_gen_model.set_weights(gan_model.layers[0].get_weights())
                updated_gen_model.save(f'models/{name}.h5')
                
        # evaluate the model performance every few epochs
        summarisePerf2(i, genModel, discModel, dataset, latentDim)
    return bestFID #return best FID score

Tuning Kernel Size

In [ ]:
%%time
results={} #dictionary
kList=[3,4,5]
# load image data
x_train = loadRealSamples()
# size of the latent space
latent_dim = 128

for k in kList:
    print(f'************ NOW TESTING: k={k}***************')
    bestFID=500 #reset FID
    # create the discriminator
    discModel = defineDisc()
    # create the generator
    genModel = defineGen(latent_dim,k)
    # create the gan
    gan_model = defineGAN(genModel, discModel)
    # train model
    bestFID=trainGAN(genModel, discModel, gan_model, x_train,
                latent_dim, n_epochs=125,n_batch=64,fid_frequency=10,sumFrequency=1,name=f'modelHPK{k}',best_FID=bestFID)
    results[k]=bestFID #add result to dictionary
    print(f'Best for k={k}: {bestFID}')
************ NOW TESTING: k=3***************
>epoch1, 937/937, dLR:0.00015000000000000001, gLR:0.0002, d1=0.484, d2=0.090, g=3.2357
>Accuracy real: 83%, fake: 100%
{3: tensor(58.5129, device='cuda:0'), 4: tensor(58.5129, device='cuda:0'), 5: tensor(58.5129, device='cuda:0')}

Summary

Best k was 3

Tuning D and G Learning Rates

It is better to tune these 2 together in combinations because they have a close relationship

In [ ]:
%%time
import os
#SELECT K FROM PREVIOUS TUNING
k=3
results={} #dictionary
#3 combinations
lrList=[(0.00015,0.00025),(0.0002,0.0002),(0.00025,0.00015)]
# load image data
x_train = loadRealSamples()
# size of the latent space
latent_dim = 128


for dLR,gLR in lrList:
    print(f'************ NOW TESTING: dLR={dLR},gLR={gLR}***************')
    bestFID=500
    # create the discriminator
    discModel = defineDisc()
    # create the generator
    genModel = defineGen(latent_dim,k)
    # create the gan
    gan_model = defineGAN(genModel, discModel)

    # train model
    bestFID=trainGAN(genModel, discModel, gan_model, x_train,
                latent_dim, n_epochs=125,n_batch=64,fid_frequency=10,sumFrequency=1,name=f'modelHP{dLR}D{gLR}G',best_FID=bestFID,
                     dLR=dLR,gLR=gLR)
    results[(dLR,gLR)]=bestFID.item() #add result to dictionary
    print(f'Best for dLR={dLR} and gLR={gLR}: {bestFID}')
In [21]:
print(results)
{(0.00015, 0.00025): 54.333892822265625, (0.0002, 0.0002): 57.90895462036133, (0.00025, 0.00015): 60.807586669921875}

Final FID test

Best parameters:

  • 0.00015 dLR
  • 0.00025 gLR
  • 3 kernel size
    let's now test FID on full 60K images set for most accurate FID
In [24]:
#load
bestModel=tf.keras.models.load_model('models/modelHP0.00015D0.00025G.h5',custom_objects={'gaussian_weight_init': gaussian_weight_init})
bestModel.summary()
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_23"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_1 (Sequential)    (None, 32, 32, 3)         1064195   
=================================================================
Total params: 1,064,195
Trainable params: 1,055,363
Non-trainable params: 8,832
_________________________________________________________________
In [25]:
from torchmetrics.image.fid import FrechetInceptionDistance
import torch
fid=FrechetInceptionDistance(reset_real_features=False).to('cuda')
fid.reset()
real_images = loadRealSamples()
real_images = real_images.astype(np.float32)
real_images=torch.from_numpy(np.array(real_images))

#change dtypes, scale back first
real_images=(real_images+1)*(255/2)
real_images=real_images.type(torch.uint8)
#from 50,000x32x32x3 to 50,000x3x32x32
real_images = real_images.permute(0, 3, 1, 2)
real_images = real_images.to('cuda')
#COMPUTE FID

# create a dataloader with batch size 128
dataloader = torch.utils.data.DataLoader(real_images, batch_size=128, shuffle=True,
                                          drop_last=True
                                          )


# update the FID with real images
for batch,images in enumerate(dataloader):
    print(f'Updating with real images now...Batch{batch}',end='\r')
    fid.update(images, real=True)

print(f'\nFinished! Ready to take in generated image features')
Updating with real images now...Batch467
Finished! Ready to take in generated image features
In [26]:
def computeFID(generator,latentDim):
    fid.reset()
    #PREPARE DATA
    random_latent_vectors = tf.random.normal(shape=(60000, 128))
    gen_images=generator.predict(random_latent_vectors)
#     gen_images=np.reshape(gen_images,(50000,3,32,32))
    gen_images = gen_images.astype(np.float32)
    gen_images = torch.from_numpy(np.array(gen_images))
    gen_images=(gen_images+1)*(255/2)
    gen_images=gen_images.type(torch.uint8)
    gen_images = gen_images.permute(0, 3, 1, 2)
    gen_images = gen_images.to('cuda')

    # create a dataloader with batch size 128
    dataloader_gen = torch.utils.data.DataLoader(gen_images, batch_size=128, shuffle=True,
                                            # pin_memory=True,
                                             drop_last=True)
    
    print('')
    # update the FID with generated images
    for batch,images in enumerate(dataloader_gen):
        print(f'Updating FID with generated images: Batch{batch}',end='\r')
        fid.update(images, real=False)
    
    print('\n')
    score = fid.compute()
    print(f'FID: {score.item()}')
    return score
In [28]:
computeFID(bestModel,128)
print('')
Updating FID with generated images: Batch467

FID: 29.09366989135742

Conclusion

Our final tuned DCGAN model was able to achieve a FID of 29 which is quite decent. To compare to state-of-the-art benchmarks, this is 100th place. Best state-of-the-art models are <10 FID but those are much more complex models and many use the Pytorch StyleGAN library to achieve better implementations. Overall I am quite satisifed with my work. The final 1000 image generated plots can be found in finalImgs folder for the final DCGAN, CGAN and WGAN respectively.

Generate 1000 individual images

In [11]:
#load
bestModel=tf.keras.models.load_model('models/finalDCGAN.h5',custom_objects={'gaussian_weight_init': gaussian_weight_init})
bestModel.summary()
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_106"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_76 (Sequential)   (None, 32, 32, 3)         1090499   
=================================================================
Total params: 1,090,499
Trainable params: 1,081,667
Non-trainable params: 8,832
_________________________________________________________________
In [126]:
imgs=generateFakeSamples(bestModel, 100, 1000)[0]
imgs=reverseProcess(imgs)
from PIL import Image

for i in range(1000):
    im = Image.fromarray(imgs[i]).convert('RGB')
    im.save(f"finalImgs/{i}.png")
In [32]:
imgs=generateFakeSamples(bestModel, 100, 1000)[0]
imgs=reverseProcess(imgs)
#plot some

for i in range(4 * 4):
        # define subplot
        plt.subplot(4, 4, 1 + i)
        # turn off axis
        plt.axis('off')
        # plot raw pixel data
        plt.imshow(imgs[i])
plt.show()
In [27]:
#Zip up whole folder
import shutil
shutil.make_archive('final', 'tar', '../ca2')
Out[27]:
'/home/p2112790/ca2/final.tar'